Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Sub.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 #include "Builders.h"
18 #include "kernels/Utils.h"
19
20 #include "kernels/BinaryOpCommon.h"
21
22 #include "PALSub.h"
23
24 namespace luci_interpreter
25 {
26
27 void configure_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
28 {
29   kernels::TISOKernel kernel(cur_op, runtime_graph);
30
31   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
32                          Tensor::element_type(kernel.input2()));
33   LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
34                          Tensor::element_type(kernel.input2()));
35 #ifndef DIS_QUANT
36   if (Tensor::element_type(kernel.input1()) == DataType::S16)
37   {
38     LUCI_INTERPRETER_CHECK(Tensor::zero_points(kernel.input1()).size() == 1 &&
39                            Tensor::zero_points(kernel.input2()).size() == 1);
40     LUCI_INTERPRETER_CHECK(Tensor::zero_point(kernel.input1()) == 0 &&
41                            Tensor::zero_point(kernel.input2()) == 0 &&
42                            Tensor::zero_point(kernel.output()) == 0);
43   }
44 #endif // DIS_QUANT
45 }
46
47 void execute_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
48 {
49   kernels::TISOKernel kernel(cur_op, runtime_graph);
50
51   const auto *options = cur_op->builtin_options_as_SubOptions();
52
53   luci_interpreter::RuntimeShape input_shape1 =
54     kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph);
55   luci_interpreter::RuntimeShape input_shape2 =
56     kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
57
58   bool is_inplace = runtime_graph->is_inplace_op(cur_op);
59
60   switch (Tensor::element_type(kernel.input1()))
61   {
62 #ifndef DIS_FLOAT
63     case DataType::FLOAT32:
64     {
65       auto tiso_func = luci_interpreter_pal::Sub<float>;
66
67       auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<float>;
68       if (is_inplace)
69       {
70         kernels::evalTISOInplaceKernel<float>(tiso_func, broadcast_tiso_func, &kernel, options,
71                                               std::move(input_shape1), std::move(input_shape2));
72       }
73       else
74       {
75         kernels::TISOData kernel_data = kernel.readData();
76         kernels::evalTISOKernel<float>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
77                                        options, std::move(input_shape1), std::move(input_shape2));
78       }
79     }
80     break;
81 #endif // DIS_FLOAT
82     case DataType::S64:
83     {
84       auto tiso_func = luci_interpreter_pal::Sub<int64_t>;
85
86       auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int64_t>;
87
88       if (is_inplace)
89       {
90         kernels::evalTISOInplaceKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, options,
91                                                 std::move(input_shape1), std::move(input_shape2));
92       }
93       else
94       {
95         kernels::TISOData kernel_data = kernel.readData();
96         kernels::evalTISOKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
97                                          options, std::move(input_shape1), std::move(input_shape2));
98       }
99     }
100     break;
101     case DataType::S32:
102     {
103       auto tiso_func = luci_interpreter_pal::Sub<int32_t>;
104
105       auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int32_t>;
106
107       if (is_inplace)
108       {
109         kernels::evalTISOInplaceKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, options,
110                                                 std::move(input_shape1), std::move(input_shape2));
111       }
112       else
113       {
114         kernels::TISOData kernel_data = kernel.readData();
115         kernels::evalTISOKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
116                                          options, std::move(input_shape1), std::move(input_shape2));
117       }
118     }
119     break;
120 // TODO: fix it
121 #if 0
122 #ifndef DIS_QUANT
123     case DataType::U8:
124     {
125       auto tiso_func = [](const tflite::ArithmeticParams &params,
126                           const tflite::RuntimeShape &input1_shape, const uint8_t *input1_data,
127                           const tflite::RuntimeShape &input2_shape, const uint8_t *input2_data,
128                           const tflite::RuntimeShape &output_shape, uint8_t *output_data) {
129         tflite::reference_ops::Sub(params, input1_shape, input1_data, input2_shape, input2_data,
130                                    output_shape, output_data);
131       };
132       auto broadcast_tiso_func =
133         [](const tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
134            const uint8_t *input1_data, const tflite::RuntimeShape &input2_shape,
135            const uint8_t *input2_data, const tflite::RuntimeShape &output_shape,
136            uint8_t *output_data) {
137           tflite::reference_ops::BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
138                                                   input2_data, output_shape, output_data);
139         };
140       if (is_inplace)
141       {
142         kernels::evalTISOInplaceQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
143                                                          options);
144       }
145       else
146       {
147         kernels::TISOData kernel_data = kernel.readData();
148         kernels::evalTISOQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
149                                                   &kernel_data, options);
150       }
151     }
152     break;
153 #endif // DIS_QUANT
154 #endif // 0
155     default:
156       assert(false && "Unsupported type.");
157   }
158 }
159
160 } // namespace luci_interpreter