Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / If.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/If.h"
18 #include "kernels/Utils.h"
19
20 #include <cstring>
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26
27 static std::vector<const Tensor *> joinInputs(const Tensor *cond,
28                                               const std::vector<const Tensor *> &inputs)
29 {
30   std::vector<const Tensor *> result{cond};
31   result.insert(result.cend(), inputs.cbegin(), inputs.cend());
32   return result;
33 }
34
35 If::If(const Tensor *cond, const std::vector<const Tensor *> &inputs, std::vector<Tensor *> outputs,
36        RuntimeGraph *then_graph, RuntimeGraph *else_graph)
37   : Kernel(joinInputs(cond, inputs), std::move(outputs)), _then_graph(then_graph),
38     _else_graph(else_graph)
39 {
40 }
41
42 void If::configure()
43 {
44   LUCI_INTERPRETER_CHECK(cond()->element_type() == DataType::BOOL);
45   LUCI_INTERPRETER_CHECK(cond()->shape().num_elements() == 1);
46
47   for (RuntimeGraph *graph : {_then_graph, _else_graph})
48   {
49     (void)graph;
50     LUCI_INTERPRETER_CHECK(graph->getInputTensors().size() == getInputTensors().size() - 1);
51     LUCI_INTERPRETER_CHECK(graph->getOutputTensors().size() == getOutputTensors().size());
52   }
53 }
54
55 void If::execute() const
56 {
57   const bool cond_value = cond()->data<bool>()[0];
58
59   RuntimeGraph *active_graph = cond_value ? _then_graph : _else_graph;
60   const auto &graph_inputs = active_graph->getInputTensors();
61   const auto &graph_outputs = active_graph->getOutputTensors();
62
63   // Copy kernel inputs to active graph inputs.
64   for (size_t i = 0; i < getInputTensors().size() - 1; ++i)
65   {
66     LUCI_INTERPRETER_CHECK(graph_inputs[i]->element_type() == input(i)->element_type());
67     graph_inputs[i]->resize(input(i)->shape());
68
69     const int32_t num_elements = input(i)->shape().num_elements();
70     const std::size_t element_size = getDataTypeSize(input(i)->element_type());
71     // TODO: Think about how allocate memory for output in main graph
72     active_graph->configureAllocations(graph_inputs[i]);
73     std::memcpy(graph_inputs[i]->data<void>(), input(i)->data<void>(), num_elements * element_size);
74   }
75
76   active_graph->execute();
77
78   // Copy graph outputs to kernel outputs.
79   for (size_t i = 0; i < getOutputTensors().size(); ++i)
80   {
81     LUCI_INTERPRETER_CHECK(graph_outputs[i]->element_type() == output(i)->element_type());
82     output(i)->resize(graph_outputs[i]->shape());
83     // TODO: Think about how allocate memory for output in main graph
84     active_graph->configureAllocations(output(i));
85
86     const int32_t num_elements = output(i)->shape().num_elements();
87     const std::size_t element_size = getDataTypeSize(output(i)->element_type());
88     std::memcpy(output(i)->data<void>(), graph_outputs[i]->data<void>(),
89                 num_elements * element_size);
90   }
91 }
92
93 } // namespace kernels
94 } // namespace luci_interpreter