Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Fill.cpp
index 6707c70..8bc5014 100644 (file)
@@ -1,6 +1,6 @@
 /*
  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * limitations under the License.
  */
 
-#include "kernels/Fill.h"
+#include "Builders.h"
+#include "TISOKernel.h"
 #include "kernels/Utils.h"
-#include "PALFill.h"
 
 namespace luci_interpreter
 {
-namespace kernels
+namespace
 {
 
-Fill::Fill(const Tensor *dims, const Tensor *value, Tensor *output)
-  : Kernel({dims, value}, {output})
+template <typename T> void fillImpl(const size_t flat_size, const T *value_data, T *output_data)
 {
-}
-
-template <typename T> void Fill::configureShape()
-{
-  const auto dims_data = getTensorData<T>(dims());
-  Shape output_shape(dims()->shape().dim(0));
-
-  for (int i = 0; i < output_shape.num_dims(); ++i)
+  for (int i = 0; i < flat_size; ++i)
   {
-    T data = dims_data[i];
-    if (data < 0)
-      assert(false && "Fill dimensions must be >= 0");
-
-    output_shape.dim(i) = data;
+    output_data[i] = *value_data;
   }
-  // TODO: enable it only if kernel with dynamic shapes
-  output()->resize(output_shape);
 }
 
-void Fill::configure()
+} // namespace
+
+void configure_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  const auto dims_shape = dims()->shape();
-  const auto value_shape = value()->shape();
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
+  // value tensor must be a scalar or has one element
+  LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input2()) == 0 or
+                         Tensor::num_elements(kernel.input2()) == 1);
+  // value and output type must match
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) ==
+                         Tensor::element_type(kernel.output()));
+}
 
-  // Make sure the 1st input tensor is 1-D
-  LUCI_INTERPRETER_CHECK(dims_shape.num_dims() == 1);
+void execute_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
+{
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-  // Make sure the 1st input tensor is int32 or int64
-  LUCI_INTERPRETER_CHECK(dims()->element_type() == DataType::S32 or
-                         dims()->element_type() == DataType::S64);
+  const circle::Tensor *value = kernel.input2();
+  const circle::Tensor *output = kernel.output();
 
-  // Make sure the 2nd input tensor is a scalar
-  LUCI_INTERPRETER_CHECK(value_shape.num_dims() == 0)
+  kernels::TISOData tiso_data = kernel.readData();
+  const uint8_t *value_data = tiso_data.input2_data;
+  uint8_t *output_data = tiso_data.output_data;
 
-  // Check zero point and scale for S16 and S8
-  if (value()->element_type() == DataType::S16 or value()->element_type() == DataType::S8)
-  {
-    LUCI_INTERPRETER_CHECK(value()->scale() == output()->scale());
-    LUCI_INTERPRETER_CHECK(value()->zero_point() == output()->zero_point());
+  const size_t flat_size = Tensor::num_elements(output);
 
-    if (value()->element_type() == DataType::S16)
-      LUCI_INTERPRETER_CHECK(value()->zero_point() == 0);
-  }
-  // Resize output
-  switch (dims()->element_type())
+  switch (Tensor::element_type(value))
   {
-    case DataType::S32:
-      configureShape<int32_t>();
-      break;
-    case DataType::S64:
-      configureShape<int64_t>();
-      break;
-    default:
-      assert(false && "Unsupported type.");
-  }
-}
-
-void Fill::execute() const
-{
-  switch (output()->element_type())
-  {
-    case DataType::S8:
-      tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int8_t>(value()),
-                                  getTensorShape(output()), getTensorData<int8_t>(output()));
-      break;
-    case DataType::S16:
-      tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int16_t>(value()),
-                                  getTensorShape(output()), getTensorData<int16_t>(output()));
+#ifndef DIS_FLOAT
+    case DataType::FLOAT32:
+      fillImpl<float>(flat_size, kernels::getTensorData<float>(value_data),
+                      kernels::getTensorData<float>(output_data));
       break;
+#endif // DIS_FLOAT
     case DataType::S32:
-      tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int32_t>(value()),
-                                  getTensorShape(output()), getTensorData<int32_t>(output()));
+      fillImpl<int32_t>(flat_size, kernels::getTensorData<int32_t>(value_data),
+                        kernels::getTensorData<int32_t>(output_data));
       break;
-    case DataType::S64:
-      tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int64_t>(value()),
-                                  getTensorShape(output()), getTensorData<int64_t>(output()));
-      break;
-    case DataType::FLOAT32:
-      tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<float>(value()),
-                                  getTensorShape(output()), getTensorData<float>(output()));
+#ifndef DIS_QUANT
+    case DataType::U8:
+      fillImpl<uint8_t>(flat_size, kernels::getTensorData<uint8_t>(value_data),
+                        kernels::getTensorData<uint8_t>(output_data));
       break;
+#endif // DIS_QUANT
     default:
-      assert(false && "Unsupported type.");
+      assert(false && "Not impl yet");
   }
 }
 
-} // namespace kernels
 } // namespace luci_interpreter