Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / MISOKernel.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef LUCI_INTERPRETER_KERNELS_MISO_KERNEL_H
18 #define LUCI_INTERPRETER_KERNELS_MISO_KERNEL_H
19
20 #include "Builders.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26
27 // Multiple input single output kernel
28 class MISOKernel
29 {
30 public:
31   MISOKernel(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
32   {
33     const auto input1_index = cur_op->inputs()->operator[](0);
34     const auto input2_index = cur_op->inputs()->operator[](1);
35     const auto input3_index = cur_op->inputs()->operator[](2);
36     const auto input4_index =
37       cur_op->inputs()->size() == 4 ? cur_op->inputs()->operator[](3) : -1; // optional
38
39     const auto output_index = cur_op->outputs()->operator[](0);
40
41     assert(input1_index != -1);
42     assert(input2_index != -1);
43     assert(input3_index != -1);
44
45     assert(output_index != -1);
46
47     _input1_tensor = runtime_graph->getCircleTensorByIndex(input1_index);
48     _input2_tensor = runtime_graph->getCircleTensorByIndex(input2_index);
49     _input3_tensor = runtime_graph->getCircleTensorByIndex(input3_index);
50     _output_tensor = runtime_graph->getCircleTensorByIndex(output_index);
51
52     // optional
53     if (input4_index != -1)
54       _input4_tensor = runtime_graph->getCircleTensorByIndex(input4_index);
55     else
56       _input4_tensor = nullptr;
57
58     assert(_input1_tensor != nullptr);
59     assert(_input2_tensor != nullptr);
60     assert(_input3_tensor != nullptr);
61     assert(_output_tensor != nullptr);
62   }
63
64   const circle::Tensor *input1() const { return _input1_tensor; }
65   const circle::Tensor *input2() const { return _input2_tensor; }
66   const circle::Tensor *input3() const { return _input3_tensor; }
67
68   const circle::Tensor *input4() const
69   {
70     assert(_input4_tensor != nullptr);
71     return _input4_tensor;
72   }
73
74   const circle::Tensor *output() const { return _output_tensor; }
75
76 private:
77   const circle::Tensor *_input1_tensor;
78   const circle::Tensor *_input2_tensor;
79   const circle::Tensor *_input3_tensor;
80   const circle::Tensor *_input4_tensor; // optional
81   const circle::Tensor *_output_tensor;
82 };
83
84 } // namespace kernels
85 } // namespace luci_interpreter
86
87 #endif // LUCI_INTERPRETER_KERNELS_MISO_KERNEL_H