/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* limitations under the License.
*/
-#include "kernels/Equal.h"
+#include "Builders.h"
#include "kernels/Utils.h"
+#include "TISOKernel.h"
-#include <tensorflow/lite/kernels/internal/reference/comparisons.h>
+#include "PALComparisons.h"
namespace luci_interpreter
{
-namespace kernels
+namespace
{
-
-Equal::Equal(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {}
-
-void Equal::configure()
+// TODO: reduce code duplication with less
+template <typename T>
+void evalGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
+ BaseRuntimeGraph *runtime_graph)
{
- LUCI_INTERPRETER_CHECK(x()->element_type() == y()->element_type());
- LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::BOOL);
+ auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x));
+ if (x_data == nullptr)
+ x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x));
- if (x()->element_type() == DataType::U8)
- {
- quantizeMultiplierSmallerThanOneExp(x()->scale(), &_x_multiplier, &_x_shift);
- quantizeMultiplierSmallerThanOneExp(y()->scale(), &_y_multiplier, &_y_shift);
- }
- // TODO: enable it only if kernel with dynamic shapes
- output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
-}
+ assert(x_data != nullptr);
-void Equal::execute() const
-{
- switch (x()->element_type())
- {
- case DataType::FLOAT32:
- evalFloat();
- break;
- case DataType::S64:
- evalInteger<int64_t>();
- break;
- case DataType::S32:
- evalInteger<int32_t>();
- break;
- case DataType::U8:
- evalQuantized();
- break;
- default:
- assert(false && "Unsupported type.");
- }
-}
+ auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y));
+ if (y_data == nullptr)
+ y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y));
-void Equal::evalFloat() const
-{
- const auto x_data = getTensorData<float>(x());
- const auto y_data = getTensorData<float>(y());
- auto output_data = getTensorData<bool>(output());
+ assert(y_data != nullptr);
+
+ auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));
- tflite::ComparisonParams op_params;
- op_params.is_broadcast = x()->shape() != y()->shape();
+ luci_interpreter_pal::ComparisonParams op_params;
+ op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);
if (op_params.is_broadcast)
{
- tflite::reference_ops::Broadcast4DSlowEqual(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data,
- getTensorShape(output()), output_data);
+ luci_interpreter_pal::BroadcastComparison4DSlowNoScaling<T>(
+ op_params, kernels::getTensorShape(x), x_data, kernels::getTensorShape(y), y_data,
+ kernels::getTensorShape(output), output_data, luci_interpreter_pal::EqualFn);
}
else
{
- tflite::reference_ops::Equal(op_params, getTensorShape(x()), x_data, getTensorShape(y()),
- y_data, getTensorShape(output()), output_data);
+ const int64_t flat_size = kernels::getTensorShape(x).flatSize();
+ luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data,
+ luci_interpreter_pal::EqualFn);
}
}
-template <typename T> void Equal::evalInteger() const
-{
- const auto x_data = getTensorData<T>(x());
- const auto y_data = getTensorData<T>(y());
- auto output_data = getTensorData<bool>(output());
+} // namespace
- tflite::ComparisonParams op_params;
- op_params.is_broadcast = x()->shape() != y()->shape();
+void configure_kernel_CircleEqual(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
+{
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
- if (op_params.is_broadcast)
- {
- tflite::reference_ops::Broadcast4DSlowEqualNoScaling(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data,
- getTensorShape(output()), output_data);
- }
- else
- {
- tflite::reference_ops::EqualNoScaling(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data, getTensorShape(output()),
- output_data);
- }
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+ Tensor::element_type(kernel.input2()));
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::BOOL);
}
-void Equal::evalQuantized() const
+void execute_kernel_CircleEqual(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- const auto x_data = getTensorData<uint8_t>(x());
- const auto y_data = getTensorData<uint8_t>(y());
- auto output_data = getTensorData<bool>(output());
-
- tflite::ComparisonParams op_params;
- op_params.left_shift = 8;
- op_params.input1_offset = -x()->zero_point(); // Note the '-'
- op_params.input1_shift = _x_shift;
- op_params.input1_multiplier = _x_multiplier;
- op_params.input2_offset = -y()->zero_point(); // Note the '-'
- op_params.input2_shift = _y_shift;
- op_params.input2_multiplier = _y_multiplier;
- op_params.is_broadcast = x()->shape() != y()->shape();
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
- if (op_params.is_broadcast)
- {
- tflite::reference_ops::Broadcast4DSlowEqualWithScaling(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data,
- getTensorShape(output()), output_data);
- }
- else
+ switch (Tensor::element_type(kernel.input1()))
{
- tflite::reference_ops::EqualWithScaling(op_params, getTensorShape(x()), x_data,
- getTensorShape(y()), y_data, getTensorShape(output()),
- output_data);
+ case DataType::S64:
+ evalGeneric<int64_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
+ break;
+ case DataType::S32:
+ evalGeneric<int32_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
+ break;
+#ifndef DIS_FLOAT
+ case DataType::FLOAT32:
+ evalGeneric<float>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
+ break;
+#endif // DIS_FLOAT
+ default:
+ assert(false && "Unsupported type.");
}
}
-} // namespace kernels
} // namespace luci_interpreter