* limitations under the License.
*/
#include "Builders.h"
#include "kernels/Utils.h"

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>
namespace luci_interpreter
{
-namespace kernels
-{
-namespace
+void configure_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
+ auto *main_runtime_graph = runtime_graph;
+
+ auto *runtime_module = runtime_graph->getRuntimeModule();
+
+ const auto *options = cur_op->builtin_options_as_WhileOptions();
+ const auto body_subgraph_index = options->body_subgraph_index();
+ const auto cond_subgraph_index = options->cond_subgraph_index();
+
+ auto *cond_runtime_graph = runtime_module->getRuntimeGraphAt(cond_subgraph_index);
+ auto *body_runtime_graph = runtime_module->getRuntimeGraphAt(body_subgraph_index);
+
+ body_runtime_graph->selectOwnSubgraph();
+ const auto body_input_size = body_runtime_graph->getNumOfInputTensors();
+ const auto body_output_size = body_runtime_graph->getNumOfOutputTensors();
+ LUCI_INTERPRETER_CHECK(body_input_size == cur_op->inputs()->size());
+ LUCI_INTERPRETER_CHECK(body_output_size == cur_op->outputs()->size());
+ LUCI_INTERPRETER_CHECK(body_output_size == cur_op->inputs()->size());
+ body_runtime_graph->invalidate();
+ body_runtime_graph->configure(false);
+
+ cond_runtime_graph->selectOwnSubgraph();
+ const auto cond_input_size = cond_runtime_graph->getNumOfInputTensors();
+ const auto cond_output_size = cond_runtime_graph->getNumOfOutputTensors();
+ LUCI_INTERPRETER_CHECK(cond_input_size == cur_op->inputs()->size());
+ LUCI_INTERPRETER_CHECK(cond_output_size == 1);
+ const circle::Tensor *cond_output_tensor = cond_runtime_graph->getOutputTensorByIndex(0);
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(cond_output_tensor) == DataType::BOOL);
+ cond_runtime_graph->invalidate();
+ cond_runtime_graph->configure(false);
+
+ main_runtime_graph->selectOwnSubgraph();
+}
-void copy(const std::vector<const Tensor *> &src, const std::vector<Tensor *> &dst)
+void execute_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- for (size_t i = 0; i < src.size(); ++i)
- {
- LUCI_INTERPRETER_CHECK(dst[i]->element_type() == src[i]->element_type());
- dst[i]->resize(src[i]->shape());
+ auto *main_runtime_graph = runtime_graph;
+ auto *runtime_module = runtime_graph->getRuntimeModule();
- const int32_t num_elements = src[i]->shape().num_elements();
- const std::size_t element_size = getDataTypeSize(src[i]->element_type());
- std::memcpy(dst[i]->data<void>(), src[i]->data<void>(), num_elements * element_size);
- }
-}
+ const auto input_size = cur_op->inputs()->size();
-void copy(const std::vector<Tensor *> &src, const std::vector<Tensor *> &dst)
-{
- std::vector<const Tensor *> const_src;
- for (const auto &t : src)
- const_src.push_back(t);
- copy(const_src, dst);
-}
+ std::vector<uint8_t *> operation_inputs_data(input_size);
+ std::vector<uint8_t *> operation_outputs_data;
-// TODO: Think about how allocate memory for output in main graph
-void configureTensorsAllocations(const std::vector<Tensor *> &tensors, RuntimeGraph *run_graph)
-{
- for (auto tensor : tensors)
- run_graph->configureAllocations(tensor);
-}
+ std::vector<int32_t> input_sizes(input_size);
-} // namespace
+ bool is_inplace = runtime_graph->is_inplace_op(cur_op);
-While::While(std::vector<const Tensor *> inputs, std::vector<Tensor *> outputs,
- RuntimeGraph *cond_graph, RuntimeGraph *body_graph)
- : Kernel(std::move(inputs), std::move(outputs)), _cond_graph(cond_graph), _body_graph(body_graph)
-{
-}
+ for (int32_t i = 0; i < input_size; ++i)
+ {
+ const auto op_input_index = cur_op->inputs()->operator[](i);
+ const auto op_output_index = cur_op->outputs()->operator[](i);
+ assert(op_input_index != -1);
+ assert(op_output_index != -1);
+ const auto input = main_runtime_graph->getCircleTensorByIndex(op_input_index);
+ const auto output = main_runtime_graph->getCircleTensorByIndex(op_output_index);
+
+ input_sizes[i] = Tensor::num_elements(input) * size(Tensor::element_type(input));
+
+ auto *input_data = main_runtime_graph->getDataByTensor(input);
+
+ uint8_t *tensor_data = nullptr;
+ if (is_inplace)
+ {
+ if (input_data == nullptr)
+ {
+ tensor_data = new uint8_t[input_sizes[i]];
+ input_data = main_runtime_graph->getConstDataByTensor(input);
+ assert(input_data != nullptr);
+ std::memcpy(tensor_data, input_data, input_sizes[i]);
+ }
+ else
+ {
+ tensor_data = input_data;
+ }
+ }
+ else
+ {
+ if (input_data == nullptr)
+ input_data = main_runtime_graph->getConstDataByTensor(input);
+ assert(input_data != nullptr);
+ tensor_data = main_runtime_graph->getDataByTensor(output);
+ assert(tensor_data != nullptr);
+ std::memcpy(tensor_data, input_data, input_sizes[i]);
+ }
+ assert(tensor_data != nullptr);
+
+ operation_inputs_data[i] = tensor_data;
+ }
-void While::configure()
-{
- LUCI_INTERPRETER_CHECK(_body_graph->getInputTensors().size() == getInputTensors().size());
- LUCI_INTERPRETER_CHECK(_body_graph->getOutputTensors().size() == getOutputTensors().size());
- LUCI_INTERPRETER_CHECK(_body_graph->getOutputTensors().size() == getInputTensors().size());
+ const auto *options = cur_op->builtin_options_as_WhileOptions();
+ const auto body_subgraph_index = options->body_subgraph_index();
+ const auto cond_subgraph_index = options->cond_subgraph_index();
- LUCI_INTERPRETER_CHECK(_cond_graph->getInputTensors().size() == getInputTensors().size());
+ auto *cond_runtime_graph = runtime_module->getRuntimeGraphAt(cond_subgraph_index);
+ auto *body_runtime_graph = runtime_module->getRuntimeGraphAt(body_subgraph_index);
- const auto &cond_outputs = _cond_graph->getOutputTensors();
- LUCI_INTERPRETER_CHECK(cond_outputs.size() == 1)
- LUCI_INTERPRETER_CHECK(cond_outputs[0]->element_type() == DataType::BOOL);
-}
+ do
+ {
+ cond_runtime_graph->selectOwnSubgraph();
-/**
- * @note Dynamic shape such as {1, 0, 8} may fail in tensor->data()
- */
-void While::execute() const
-{
- const auto &cond_inputs = _cond_graph->getInputTensors();
- const auto &cond_outputs = _cond_graph->getOutputTensors();
+ for (int32_t i = 0; i < input_size; ++i)
+ cond_runtime_graph->configureGraphInput(i, operation_inputs_data[i]);
- configureTensorsAllocations(cond_inputs, _cond_graph);
+ cond_runtime_graph->execute();
- copy(getInputTensors(), cond_inputs);
+ bool cond_value = (cond_runtime_graph->getOutputDataByIndex(0))[0];
+ if (!cond_value)
+ break;
- const auto &body_inputs = _body_graph->getInputTensors();
- const auto &body_outputs = _body_graph->getOutputTensors();
+ body_runtime_graph->selectOwnSubgraph();
+ for (int32_t i = 0; i < input_size; ++i)
+ body_runtime_graph->configureGraphInput(i, operation_inputs_data[i]);
- configureTensorsAllocations(body_inputs, _body_graph);
+ body_runtime_graph->execute();
- while (true)
- {
- _cond_graph->execute();
+ for (int32_t i = 0; i < input_size; ++i)
+ {
+ auto cur_output_body_data = body_runtime_graph->getOutputDataByIndex(i);
+ if (cur_output_body_data == nullptr)
+ continue;
+ std::memcpy(operation_inputs_data[i], cur_output_body_data, input_sizes[i]);
+ }
+ } while (true);
- bool cond_value = cond_outputs[0]->data<bool>()[0];
- if (!cond_value)
- break;
+ cond_runtime_graph->resetOutputTensorsData();
+ cond_runtime_graph->clearTensors();
- copy(cond_inputs, body_inputs);
+ body_runtime_graph->selectOwnSubgraph();
+ body_runtime_graph->resetOutputTensorsData();
+ body_runtime_graph->clearTensors();
- _body_graph->execute();
+ main_runtime_graph->selectOwnSubgraph();
- copy(body_outputs, cond_inputs);
+ if (is_inplace)
+ {
+ for (int32_t i = 0; i < input_size; ++i)
+ {
+ const auto op_input_index = cur_op->inputs()->operator[](i);
+ const auto op_output_index = cur_op->outputs()->operator[](i);
+ assert(op_input_index != -1);
+ assert(op_output_index != -1);
+ const auto input = main_runtime_graph->getCircleTensorByIndex(op_input_index);
+ const auto output = main_runtime_graph->getCircleTensorByIndex(op_output_index);
+
+ if (main_runtime_graph->getDataByTensor(input))
+ {
+ main_runtime_graph->makeInplaceOperation(input, output);
+ }
+ else
+ {
+ main_runtime_graph->setDataToTensor(output, operation_inputs_data[i]);
+ }
+ }
}
-
- copy(cond_inputs, getOutputTensors());
}
-} // namespace kernels
} // namespace luci_interpreter