* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-#include "kernels/Sub.h"
+#include "Builders.h"
#include "kernels/Utils.h"
-#include "PALSub.h"
+#include "kernels/BinaryOpCommon.h"
-#include <tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h>
+#include "PALSub.h"
namespace luci_interpreter
{
-namespace kernels
-{
-Sub::Sub(const Tensor *input1, const Tensor *input2, Tensor *output, const SubParams &params)
-  : KernelWithParams<SubParams>({input1, input2}, {output}, params)
+void configure_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
+
+  // Both inputs must share one element type, and the output must match it too
+  // (mirrors the checks the removed Sub::configure performed).
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+                         Tensor::element_type(kernel.input2()));
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+                         Tensor::element_type(kernel.output()));
+#ifndef DIS_QUANT
+  if (Tensor::element_type(kernel.input1()) == DataType::S16)
+  {
+    // S16 path assumes per-tensor symmetric quantization: exactly one zero
+    // point per tensor, and all zero points equal to 0.
+    LUCI_INTERPRETER_CHECK(Tensor::zero_points(kernel.input1()).size() == 1 &&
+                           Tensor::zero_points(kernel.input2()).size() == 1);
+    LUCI_INTERPRETER_CHECK(Tensor::zero_point(kernel.input1()) == 0 &&
+                           Tensor::zero_point(kernel.input2()) == 0 &&
+                           Tensor::zero_point(kernel.output()) == 0);
+  }
+#endif // DIS_QUANT
 }
-void Sub::configure()
+void execute_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  LUCI_INTERPRETER_CHECK(!(input1()->element_type() != input2()->element_type()))
-  LUCI_INTERPRETER_CHECK(!(input1()->element_type() != output()->element_type()))
-  // TODO: enable it only if kernel with dynamic shapes
-  output()->resize(calculateShapeForBroadcast(input1()->shape(), input2()->shape()));
-}
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
-void Sub::execute() const
-{
-  switch (input1()->element_type())
+  const auto *options = cur_op->builtin_options_as_SubOptions();
+
+  // Resolve the (possibly dynamic) input shapes once up front; each case below
+  // moves them into the shared TISO eval helpers.
+  luci_interpreter::RuntimeShape input_shape1 =
+    kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph);
+  luci_interpreter::RuntimeShape input_shape2 =
+    kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
+
+  // When the op is marked in-place the output buffer aliases an input buffer,
+  // so a dedicated in-place eval path is used.
+  bool is_inplace = runtime_graph->is_inplace_op(cur_op);
+
+  switch (Tensor::element_type(kernel.input1()))
   {
+#ifndef DIS_FLOAT
     case DataType::FLOAT32:
-      evalFloat();
-      break;
+    {
+      auto tiso_func = luci_interpreter_pal::Sub<float>;
+
+      auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<float>;
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceKernel<float>(tiso_func, broadcast_tiso_func, &kernel, options,
+                                              std::move(input_shape1), std::move(input_shape2));
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOKernel<float>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
+                                       options, std::move(input_shape1), std::move(input_shape2));
+      }
+    }
+    break;
+#endif // DIS_FLOAT
     case DataType::S64:
-      evalInteger<int64_t>();
-      break;
+    {
+      auto tiso_func = luci_interpreter_pal::Sub<int64_t>;
+
+      auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int64_t>;
+
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, options,
+                                                std::move(input_shape1), std::move(input_shape2));
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
+                                         options, std::move(input_shape1), std::move(input_shape2));
+      }
+    }
+    break;
     case DataType::S32:
-      evalInteger<int32_t>();
-      break;
+    {
+      auto tiso_func = luci_interpreter_pal::Sub<int32_t>;
+
+      auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int32_t>;
+
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, options,
+                                                std::move(input_shape1), std::move(input_shape2));
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
+                                         options, std::move(input_shape1), std::move(input_shape2));
+      }
+    }
+    break;
+// TODO: fix it
+#if 0
+#ifndef DIS_QUANT
     case DataType::U8:
-      evalQuantized();
-      break;
+    {
+      auto tiso_func = [](const tflite::ArithmeticParams &params,
+                          const tflite::RuntimeShape &input1_shape, const uint8_t *input1_data,
+                          const tflite::RuntimeShape &input2_shape, const uint8_t *input2_data,
+                          const tflite::RuntimeShape &output_shape, uint8_t *output_data) {
+        tflite::reference_ops::Sub(params, input1_shape, input1_data, input2_shape, input2_data,
+                                   output_shape, output_data);
+      };
+      auto broadcast_tiso_func =
+        [](const tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
+           const uint8_t *input1_data, const tflite::RuntimeShape &input2_shape,
+           const uint8_t *input2_data, const tflite::RuntimeShape &output_shape,
+           uint8_t *output_data) {
+          tflite::reference_ops::BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
+                                                  input2_data, output_shape, output_data);
+        };
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
+                                                         options);
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
+                                                  &kernel_data, options);
+      }
+    }
+    break;
+#endif // DIS_QUANT
+#endif // 0
     default:
       assert(false && "Unsupported type.");
   }
 }
-void Sub::evalFloat() const
-{
-  tflite::ArithmeticParams params{};
-  fillArithmeticActivationRange<float>(params, _params.activation);
-
-  const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
-    getTensorShape(input1()), getTensorShape(input2()), &params);
-
-  if (need_broadcast)
-  {
-    tflite::reference_ops::BroadcastSubSlow(
-      params, getTensorShape(input1()), getTensorData<float>(input1()), getTensorShape(input2()),
-      getTensorData<float>(input2()), getTensorShape(output()), getTensorData<float>(output()));
-  }
-  else
-  {
-    luci_interpreter_pal::Sub(params, getTensorShape(input1()), getTensorData<float>(input1()),
-                              getTensorShape(input2()), getTensorData<float>(input2()),
-                              getTensorShape(output()), getTensorData<float>(output()));
-  }
-}
-
-template <typename T> void Sub::evalInteger() const
-{
-  tflite::ArithmeticParams params{};
-  fillArithmeticActivationRange<T>(params, _params.activation);
-
-  const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
-    getTensorShape(input1()), getTensorShape(input2()), &params);
-
-  if (need_broadcast)
-  {
-    tflite::reference_ops::BroadcastSubSlow(
-      params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
-      getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
-  }
-  else
-  {
-    tflite::reference_ops::Sub(params, getTensorShape(input1()), getTensorData<T>(input1()),
-                               getTensorShape(input2()), getTensorData<T>(input2()),
-                               getTensorShape(output()), getTensorData<T>(output()));
-  }
-}
-
-void Sub::evalQuantized() const
-{
-  const auto input1_scale = static_cast<double>(input1()->scale());
-  const auto input2_scale = static_cast<double>(input2()->scale());
-  const auto output_scale = static_cast<double>(output()->scale());
-
-  const int left_shift = 20;
-  const double twice_max_input_scale = 2 * std::max(input1_scale, input2_scale);
-  const double real_input1_multiplier = input1_scale / twice_max_input_scale;
-  const double real_input2_multiplier = input2_scale / twice_max_input_scale;
-  const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * output_scale);
-
-  int32_t input1_multiplier{}, input2_multiplier{}, output_multiplier{};
-  int input1_shift{}, input2_shift{}, output_shift{};
-  quantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier, &input1_shift);
-  quantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier, &input2_shift);
-  quantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier, &output_shift);
-
-  int32_t activation_min{};
-  int32_t activation_max{};
-  calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
-
-  tflite::ArithmeticParams params{};
-  params.left_shift = left_shift;
-  // The kernel expects inputs' zero points to be negated.
-  params.input1_offset = -input1()->zero_point(); // Note the '-'.
-  params.input1_multiplier = input1_multiplier;
-  params.input1_shift = input1_shift;
-  params.input2_offset = -input2()->zero_point(); // Note the '-'.
-  params.input2_multiplier = input2_multiplier;
-  params.input2_shift = input2_shift;
-  params.output_offset = output()->zero_point();
-  params.output_multiplier = output_multiplier;
-  params.output_shift = output_shift;
-  params.quantized_activation_min = activation_min;
-  params.quantized_activation_max = activation_max;
-
-  const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
-    getTensorShape(input1()), getTensorShape(input2()), &params);
-
-  if (need_broadcast)
-  {
-    tflite::reference_ops::BroadcastSubSlow(
-      params, getTensorShape(input1()), getTensorData<uint8_t>(input1()), getTensorShape(input2()),
-      getTensorData<uint8_t>(input2()), getTensorShape(output()), getTensorData<uint8_t>(output()));
-  }
-  else
-  {
-    tflite::reference_ops::Sub(params, getTensorShape(input1()), getTensorData<uint8_t>(input1()),
-                               getTensorShape(input2()), getTensorData<uint8_t>(input2()),
-                               getTensorShape(output()), getTensorData<uint8_t>(output()));
-  }
-}
-
-} // namespace kernels
 } // namespace luci_interpreter