2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #include "kernels/Utils.h"
20 #include "kernels/BinaryOpCommon.h"
24 namespace luci_interpreter
27 void configure_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
29 kernels::TISOKernel kernel(cur_op, runtime_graph);
31 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
32 Tensor::element_type(kernel.input2()));
33 LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
34 Tensor::element_type(kernel.input2()));
36 if (Tensor::element_type(kernel.input1()) == DataType::S16)
38 LUCI_INTERPRETER_CHECK(Tensor::zero_points(kernel.input1()).size() == 1 &&
39 Tensor::zero_points(kernel.input2()).size() == 1);
40 LUCI_INTERPRETER_CHECK(Tensor::zero_point(kernel.input1()) == 0 &&
41 Tensor::zero_point(kernel.input2()) == 0 &&
42 Tensor::zero_point(kernel.output()) == 0);
// Runs the Sub kernel for `cur_op`: reads the operator's SubOptions and the runtime
// shapes of both inputs, then dispatches on input1's element type to a typed
// TISO (two-input, single-output) evaluation helper.
// NOTE(review): this listing is missing lines (braces, case labels, if/else/break,
// and possibly preprocessor guards); the comments below describe the visible
// statements only.
47 void execute_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
49 kernels::TISOKernel kernel(cur_op, runtime_graph);
51 const auto *options = cur_op->builtin_options_as_SubOptions();
// Shapes are computed once here and later moved into the non-quantized helper calls.
53 luci_interpreter::RuntimeShape input_shape1 =
54 kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph);
55 luci_interpreter::RuntimeShape input_shape2 =
56 kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
// Presumably selects between the evalTISOInplace* calls and the readData()-based
// evalTISO* calls in each type path (the branching lines are not visible here).
58 bool is_inplace = runtime_graph->is_inplace_op(cur_op);
60 switch (Tensor::element_type(kernel.input1()))
// FLOAT32 path: pairs the PAL Sub<float> with BroadcastSub4DSlow<float> and passes
// both to the generic TISO helpers together with the options and input shapes.
63 case DataType::FLOAT32:
65 auto tiso_func = luci_interpreter_pal::Sub<float>;
67 auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<float>;
70 kernels::evalTISOInplaceKernel<float>(tiso_func, broadcast_tiso_func, &kernel, options,
71 std::move(input_shape1), std::move(input_shape2));
75 kernels::TISOData kernel_data = kernel.readData();
76 kernels::evalTISOKernel<float>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
77 options, std::move(input_shape1), std::move(input_shape2));
// int64_t path, same structure as FLOAT32 (case label not visible; presumably DataType::S64).
84 auto tiso_func = luci_interpreter_pal::Sub<int64_t>;
86 auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int64_t>;
90 kernels::evalTISOInplaceKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, options,
91 std::move(input_shape1), std::move(input_shape2));
95 kernels::TISOData kernel_data = kernel.readData();
96 kernels::evalTISOKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
97 options, std::move(input_shape1), std::move(input_shape2));
// int32_t path, same structure as FLOAT32 (case label not visible; presumably DataType::S32).
103 auto tiso_func = luci_interpreter_pal::Sub<int32_t>;
105 auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int32_t>;
109 kernels::evalTISOInplaceKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, options,
110 std::move(input_shape1), std::move(input_shape2));
114 kernels::TISOData kernel_data = kernel.readData();
115 kernels::evalTISOKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
116 options, std::move(input_shape1), std::move(input_shape2));
// uint8_t path: unlike the other paths it wraps tflite::reference_ops::Sub and
// BroadcastSubSlow in lambdas and uses the quantized TISO helpers, which receive
// `options` but not the precomputed input shapes.
// NOTE(review): `¶ms` in both lambdas looks like a mis-encoded `&params` (an
// HTML `&para;` entity substitution); the bodies use `params`, so it must be restored.
// NOTE(review): this path depends on tflite types — confirm it is compiled in this build.
125 auto tiso_func = [](const tflite::ArithmeticParams ¶ms,
126 const tflite::RuntimeShape &input1_shape, const uint8_t *input1_data,
127 const tflite::RuntimeShape &input2_shape, const uint8_t *input2_data,
128 const tflite::RuntimeShape &output_shape, uint8_t *output_data) {
129 tflite::reference_ops::Sub(params, input1_shape, input1_data, input2_shape, input2_data,
130 output_shape, output_data);
132 auto broadcast_tiso_func =
133 [](const tflite::ArithmeticParams ¶ms, const tflite::RuntimeShape &input1_shape,
134 const uint8_t *input1_data, const tflite::RuntimeShape &input2_shape,
135 const uint8_t *input2_data, const tflite::RuntimeShape &output_shape,
136 uint8_t *output_data) {
137 tflite::reference_ops::BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
138 input2_data, output_shape, output_data);
// NOTE(review): this call appears truncated in the listing; by analogy with the
// evalTISOQuantizedKernel call below, the missing tail is presumably `options);`.
142 kernels::evalTISOInplaceQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
147 kernels::TISOData kernel_data = kernel.readData();
148 kernels::evalTISOQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
149 &kernel_data, options);
// Unsupported element type: the assert fires only in debug builds; with NDEBUG it
// compiles away and nothing is computed for this operator.
156 assert(false && "Unsupported type.");
160 } // namespace luci_interpreter