Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Sub.cpp
index 7b02c1e..5eaed32 100644 (file)
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#include "kernels/Sub.h"
+#include "Builders.h"
 #include "kernels/Utils.h"
 
-#include "PALSub.h"
+#include "kernels/BinaryOpCommon.h"
 
-#include <tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h>
+#include "PALSub.h"
 
 namespace luci_interpreter
 {
-namespace kernels
-{
 
-Sub::Sub(const Tensor *input1, const Tensor *input2, Tensor *output, const SubParams &params)
-  : KernelWithParams<SubParams>({input1, input2}, {output}, params)
+void configure_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
+
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+                         Tensor::element_type(kernel.input2()));
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+                         Tensor::element_type(kernel.input2()));
+#ifndef DIS_QUANT
+  if (Tensor::element_type(kernel.input1()) == DataType::S16)
+  {
+    LUCI_INTERPRETER_CHECK(Tensor::zero_points(kernel.input1()).size() == 1 &&
+                           Tensor::zero_points(kernel.input2()).size() == 1);
+    LUCI_INTERPRETER_CHECK(Tensor::zero_point(kernel.input1()) == 0 &&
+                           Tensor::zero_point(kernel.input2()) == 0 &&
+                           Tensor::zero_point(kernel.output()) == 0);
+  }
+#endif // DIS_QUANT
 }
 
-void Sub::configure()
+// Executes a CircleSub node. Dispatches on the element type of input1 and
+// runs the op through the shared TISO (two-inputs/single-output) helpers:
+// the plain PAL Sub handles the same-shape path and BroadcastSub4DSlow the
+// broadcast path (per the helper names — broadcast selection happens inside
+// evalTISO*Kernel; TODO confirm against kernels/BinaryOpCommon.h).
+void execute_kernel_CircleSub(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  LUCI_INTERPRETER_CHECK(!(input1()->element_type() != input2()->element_type()))
-  LUCI_INTERPRETER_CHECK(!(input1()->element_type() != output()->element_type()))
-  // TODO: enable it only if kernel with dynamic shapes
-  output()->resize(calculateShapeForBroadcast(input1()->shape(), input2()->shape()));
-}
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-void Sub::execute() const
-{
-  switch (input1()->element_type())
+  // SubOptions carries the fused activation; forwarded to the eval helpers.
+  const auto *options = cur_op->builtin_options_as_SubOptions();
+
+  luci_interpreter::RuntimeShape input_shape1 =
+    kernels::getTensorRuntimeShape(kernel.input1(), runtime_graph);
+  luci_interpreter::RuntimeShape input_shape2 =
+    kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
+
+  // NOTE(review): inplace ops presumably reuse an input buffer for the output,
+  // which is why that path skips kernel.readData() — verify against
+  // RuntimeGraph::is_inplace_op and the TISO inplace helper.
+  bool is_inplace = runtime_graph->is_inplace_op(cur_op);
+
+  switch (Tensor::element_type(kernel.input1()))
   {
+#ifndef DIS_FLOAT
     case DataType::FLOAT32:
-      evalFloat();
-      break;
+    {
+      auto tiso_func = luci_interpreter_pal::Sub<float>;
+
+      auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<float>;
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceKernel<float>(tiso_func, broadcast_tiso_func, &kernel, options,
+                                              std::move(input_shape1), std::move(input_shape2));
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOKernel<float>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
+                                       options, std::move(input_shape1), std::move(input_shape2));
+      }
+    }
+    break;
+#endif // DIS_FLOAT
+    // S64/S32 cases mirror the FLOAT32 case with a different PAL instantiation.
     case DataType::S64:
-      evalInteger<int64_t>();
-      break;
+    {
+      auto tiso_func = luci_interpreter_pal::Sub<int64_t>;
+
+      auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int64_t>;
+
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, options,
+                                                std::move(input_shape1), std::move(input_shape2));
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOKernel<int64_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
+                                         options, std::move(input_shape1), std::move(input_shape2));
+      }
+    }
+    break;
     case DataType::S32:
-      evalInteger<int32_t>();
-      break;
+    {
+      auto tiso_func = luci_interpreter_pal::Sub<int32_t>;
+
+      auto broadcast_tiso_func = luci_interpreter_pal::BroadcastSub4DSlow<int32_t>;
+
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, options,
+                                                std::move(input_shape1), std::move(input_shape2));
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOKernel<int32_t>(tiso_func, broadcast_tiso_func, &kernel, &kernel_data,
+                                         options, std::move(input_shape1), std::move(input_shape2));
+      }
+    }
+    break;
+// TODO: fix it
+// NOTE(review): the quantized U8 path below is compiled out (#if 0); it still
+// references tflite:: reference ops from the pre-migration kernel and needs
+// porting to the PAL/TISO helpers before it can be re-enabled.
+#if 0
+#ifndef DIS_QUANT
     case DataType::U8:
-      evalQuantized();
-      break;
+    {
+      auto tiso_func = [](const tflite::ArithmeticParams &params,
+                          const tflite::RuntimeShape &input1_shape, const uint8_t *input1_data,
+                          const tflite::RuntimeShape &input2_shape, const uint8_t *input2_data,
+                          const tflite::RuntimeShape &output_shape, uint8_t *output_data) {
+        tflite::reference_ops::Sub(params, input1_shape, input1_data, input2_shape, input2_data,
+                                   output_shape, output_data);
+      };
+      auto broadcast_tiso_func =
+        [](const tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
+           const uint8_t *input1_data, const tflite::RuntimeShape &input2_shape,
+           const uint8_t *input2_data, const tflite::RuntimeShape &output_shape,
+           uint8_t *output_data) {
+          tflite::reference_ops::BroadcastSubSlow(params, input1_shape, input1_data, input2_shape,
+                                                  input2_data, output_shape, output_data);
+        };
+      if (is_inplace)
+      {
+        kernels::evalTISOInplaceQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
+                                                         options);
+      }
+      else
+      {
+        kernels::TISOData kernel_data = kernel.readData();
+        kernels::evalTISOQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
+                                                  &kernel_data, options);
+      }
+    }
+    break;
+#endif // DIS_QUANT
+#endif // 0
     default:
       assert(false && "Unsupported type.");
   }
 }
 
-void Sub::evalFloat() const
-{
-  tflite::ArithmeticParams params{};
-  fillArithmeticActivationRange<float>(params, _params.activation);
-
-  const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
-    getTensorShape(input1()), getTensorShape(input2()), &params);
-
-  if (need_broadcast)
-  {
-    tflite::reference_ops::BroadcastSubSlow(
-      params, getTensorShape(input1()), getTensorData<float>(input1()), getTensorShape(input2()),
-      getTensorData<float>(input2()), getTensorShape(output()), getTensorData<float>(output()));
-  }
-  else
-  {
-    luci_interpreter_pal::Sub(params, getTensorShape(input1()), getTensorData<float>(input1()),
-                              getTensorShape(input2()), getTensorData<float>(input2()),
-                              getTensorShape(output()), getTensorData<float>(output()));
-  }
-}
-
-template <typename T> void Sub::evalInteger() const
-{
-  tflite::ArithmeticParams params{};
-  fillArithmeticActivationRange<T>(params, _params.activation);
-
-  const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
-    getTensorShape(input1()), getTensorShape(input2()), &params);
-
-  if (need_broadcast)
-  {
-    tflite::reference_ops::BroadcastSubSlow(
-      params, getTensorShape(input1()), getTensorData<T>(input1()), getTensorShape(input2()),
-      getTensorData<T>(input2()), getTensorShape(output()), getTensorData<T>(output()));
-  }
-  else
-  {
-    tflite::reference_ops::Sub(params, getTensorShape(input1()), getTensorData<T>(input1()),
-                               getTensorShape(input2()), getTensorData<T>(input2()),
-                               getTensorShape(output()), getTensorData<T>(output()));
-  }
-}
-
-void Sub::evalQuantized() const
-{
-  const auto input1_scale = static_cast<double>(input1()->scale());
-  const auto input2_scale = static_cast<double>(input2()->scale());
-  const auto output_scale = static_cast<double>(output()->scale());
-
-  const int left_shift = 20;
-  const double twice_max_input_scale = 2 * std::max(input1_scale, input2_scale);
-  const double real_input1_multiplier = input1_scale / twice_max_input_scale;
-  const double real_input2_multiplier = input2_scale / twice_max_input_scale;
-  const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * output_scale);
-
-  int32_t input1_multiplier{}, input2_multiplier{}, output_multiplier{};
-  int input1_shift{}, input2_shift{}, output_shift{};
-  quantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier, &input1_shift);
-  quantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier, &input2_shift);
-  quantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier, &output_shift);
-
-  int32_t activation_min{};
-  int32_t activation_max{};
-  calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
-
-  tflite::ArithmeticParams params{};
-  params.left_shift = left_shift;
-  // The kernel expects inputs' zero points to be negated.
-  params.input1_offset = -input1()->zero_point(); // Note the '-'.
-  params.input1_multiplier = input1_multiplier;
-  params.input1_shift = input1_shift;
-  params.input2_offset = -input2()->zero_point(); // Note the '-'.
-  params.input2_multiplier = input2_multiplier;
-  params.input2_shift = input2_shift;
-  params.output_offset = output()->zero_point();
-  params.output_multiplier = output_multiplier;
-  params.output_shift = output_shift;
-  params.quantized_activation_min = activation_min;
-  params.quantized_activation_max = activation_max;
-
-  const bool need_broadcast = tflite::reference_ops::ProcessBroadcastShapes(
-    getTensorShape(input1()), getTensorShape(input2()), &params);
-
-  if (need_broadcast)
-  {
-    tflite::reference_ops::BroadcastSubSlow(
-      params, getTensorShape(input1()), getTensorData<uint8_t>(input1()), getTensorShape(input2()),
-      getTensorData<uint8_t>(input2()), getTensorShape(output()), getTensorData<uint8_t>(output()));
-  }
-  else
-  {
-    tflite::reference_ops::Sub(params, getTensorShape(input1()), getTensorData<uint8_t>(input1()),
-                               getTensorShape(input2()), getTensorData<uint8_t>(input2()),
-                               getTensorShape(output()), getTensorData<uint8_t>(output()));
-  }
-}
-
-} // namespace kernels
 } // namespace luci_interpreter