/*
* Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* limitations under the License.
*/
-#include "kernels/Fill.h"
+#include "Builders.h"
+#include "TISOKernel.h"
#include "kernels/Utils.h"
-#include "PALFill.h"
namespace luci_interpreter
{
-namespace kernels
+namespace
{
-Fill::Fill(const Tensor *dims, const Tensor *value, Tensor *output)
- : Kernel({dims, value}, {output})
+template <typename T> void fillImpl(const size_t flat_size, const T *value_data, T *output_data)
{
-}
-
-template <typename T> void Fill::configureShape()
-{
- const auto dims_data = getTensorData<T>(dims());
- Shape output_shape(dims()->shape().dim(0));
-
- for (int i = 0; i < output_shape.num_dims(); ++i)
+ for (int i = 0; i < flat_size; ++i)
{
- T data = dims_data[i];
- if (data < 0)
- assert(false && "Fill dimensions must be >= 0");
-
- output_shape.dim(i) = data;
+ output_data[i] = *value_data;
}
- // TODO: enable it only if kernel with dynamic shapes
- output()->resize(output_shape);
}
-void Fill::configure()
+} // namespace
+
+void configure_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- const auto dims_shape = dims()->shape();
- const auto value_shape = value()->shape();
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
+ // value tensor must be a scalar or has one element
+ LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input2()) == 0 or
+ Tensor::num_elements(kernel.input2()) == 1);
+ // value and output type must match
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) ==
+ Tensor::element_type(kernel.output()));
+}
- // Make sure the 1st input tensor is 1-D
- LUCI_INTERPRETER_CHECK(dims_shape.num_dims() == 1);
+void execute_kernel_CircleFill(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
+{
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
- // Make sure the 1st input tensor is int32 or int64
- LUCI_INTERPRETER_CHECK(dims()->element_type() == DataType::S32 or
- dims()->element_type() == DataType::S64);
+ const circle::Tensor *value = kernel.input2();
+ const circle::Tensor *output = kernel.output();
- // Make sure the 2nd input tensor is a scalar
- LUCI_INTERPRETER_CHECK(value_shape.num_dims() == 0)
+ kernels::TISOData tiso_data = kernel.readData();
+ const uint8_t *value_data = tiso_data.input2_data;
+ uint8_t *output_data = tiso_data.output_data;
- // Check zero point and scale for S16 and S8
- if (value()->element_type() == DataType::S16 or value()->element_type() == DataType::S8)
- {
- LUCI_INTERPRETER_CHECK(value()->scale() == output()->scale());
- LUCI_INTERPRETER_CHECK(value()->zero_point() == output()->zero_point());
+ const size_t flat_size = Tensor::num_elements(output);
- if (value()->element_type() == DataType::S16)
- LUCI_INTERPRETER_CHECK(value()->zero_point() == 0);
- }
- // Resize output
- switch (dims()->element_type())
+ switch (Tensor::element_type(value))
{
- case DataType::S32:
- configureShape<int32_t>();
- break;
- case DataType::S64:
- configureShape<int64_t>();
- break;
- default:
- assert(false && "Unsupported type.");
- }
-}
-
-void Fill::execute() const
-{
- switch (output()->element_type())
- {
- case DataType::S8:
- tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int8_t>(value()),
- getTensorShape(output()), getTensorData<int8_t>(output()));
- break;
- case DataType::S16:
- tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int16_t>(value()),
- getTensorShape(output()), getTensorData<int16_t>(output()));
+#ifndef DIS_FLOAT
+ case DataType::FLOAT32:
+ fillImpl<float>(flat_size, kernels::getTensorData<float>(value_data),
+ kernels::getTensorData<float>(output_data));
break;
+#endif // DIS_FLOAT
case DataType::S32:
- tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int32_t>(value()),
- getTensorShape(output()), getTensorData<int32_t>(output()));
+ fillImpl<int32_t>(flat_size, kernels::getTensorData<int32_t>(value_data),
+ kernels::getTensorData<int32_t>(output_data));
break;
- case DataType::S64:
- tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<int64_t>(value()),
- getTensorShape(output()), getTensorData<int64_t>(output()));
- break;
- case DataType::FLOAT32:
- tflite::reference_ops::Fill(getTensorShape(value()), getTensorData<float>(value()),
- getTensorShape(output()), getTensorData<float>(output()));
+#ifndef DIS_QUANT
+ case DataType::U8:
+ fillImpl<uint8_t>(flat_size, kernels::getTensorData<uint8_t>(value_data),
+ kernels::getTensorData<uint8_t>(output_data));
break;
+#endif // DIS_QUANT
default:
- assert(false && "Unsupported type.");
+ assert(false && "Not impl yet");
}
}
-} // namespace kernels
} // namespace luci_interpreter