* limitations under the License.
*/
-#include "kernels/Pack.h"
-#include "kernels/Utils.h"
+#include "Builders.h"
+#include "Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+#include <cassert>
namespace luci_interpreter
{
-namespace kernels
+namespace
{
-Pack::Pack(std::vector<const Tensor *> inputs, Tensor *output, const PackParams &params)
- : KernelWithParams<PackParams>(std::move(inputs), {output}, params)
+/**
+ * Copies every Pack input into its slot of the output buffer, for element type T.
+ *
+ * For each input i (0 <= i < values_count) and each outer index k, copy_size
+ * elements starting at input offset copy_size * k are written to output offset
+ * (k * values_count + i) * copy_size.
+ *
+ * @param input0 first input tensor; only its shape is read, for the size sanity check
+ * @param output output tensor; supplies the rank and the dimensions
+ * @param cur_op operator providing PackOptions (values_count, axis) and the input indices
+ * @param runtime_graph used to resolve each input tensor and its data by index
+ * @param output_data_raw raw output buffer, accessed as T via kernels::getTensorData<T>
+ */
+template <typename T>
+void packImpl(const circle::Tensor *input0, const circle::Tensor *output,
+ const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
+ uint8_t *output_data_raw)
{
-}
+ const auto *options = cur_op->builtin_options_as_PackOptions();
+
+ // NOTE(review): values_count is not compared against cur_op->inputs()->size() in this
+ // function -- confirm the two are guaranteed to agree before the input loop below.
+ const int values_count = options->values_count();
+ int axis = options->axis();
+ const int dimensions = Tensor::num_dims(output);
+
+ const auto input_dims = wrap(input0->shape());
+ const auto output_dims = wrap(output->shape());
-void Pack::configure()
-{
- LUCI_INTERPRETER_CHECK(_inputs.size() == static_cast<uint32_t>(params().values_count));
- const Tensor *t0 = _inputs[0];
- const int dimension_size = t0->shape().num_dims() + 1;
- int axis = params().axis;
+ // A negative axis is shifted by the output rank.
+ // NOTE(review): axis is not range-checked after this adjustment -- confirm it is
+ // validated elsewhere.
if (axis < 0)
{
- axis += dimension_size;
+ axis += dimensions;
}
- LUCI_INTERPRETER_CHECK(axis >= 0 && axis <= t0->shape().num_dims());
- if (t0->element_type() != DataType::S32 && t0->element_type() != DataType::FLOAT32 &&
- t0->element_type() != DataType::U8 && t0->element_type() != DataType::S8 &&
- t0->element_type() != DataType::S16 && t0->element_type() != DataType::S64)
- {
- assert(false && "Unsupported type.");
- }
+ // outer_size: product of the output dimensions that precede axis.
+ int outer_size = 1;
+ for (int i = 0; i < axis; ++i)
+ outer_size *= output_dims[i];
- for (uint32_t i = 1; i < _inputs.size(); ++i)
- {
- const Tensor *tensor = _inputs[i];
- LUCI_INTERPRETER_CHECK(tensor->element_type() == t0->element_type());
- LUCI_INTERPRETER_CHECK(tensor->shape().num_dims() == t0->shape().num_dims());
- for (int d = 0; d < t0->shape().num_dims(); ++d)
- {
- LUCI_INTERPRETER_CHECK(tensor->shape().dim(d) == t0->shape().dim(d));
- }
- }
+ // copy_size: product of the output dimensions that follow axis.
+ int copy_size = 1;
+ for (int i = axis + 1; i < dimensions; ++i)
+ copy_size *= output_dims[i];
- Shape output_shape(dimension_size);
- int i = 0;
- for (int index = 0; index < dimension_size; ++index)
- {
- if (index == axis)
- {
- output_shape.dim(index) = params().values_count;
- }
- else
- {
- output_shape.dim(index) = t0->shape().dim(i++);
- }
- }
+ // input_size: total element count of input0; it is only read by the assertion below.
+ int input_size = 1;
+ for (int i = 0; i < input_dims.size(); ++i)
+ input_size *= input_dims[i];
+
+ // A single input must hold exactly outer_size blocks of copy_size elements.
+ assert(input_size == copy_size * outer_size);
+
+ T *output_data = kernels::getTensorData<T>(output_data_raw);
+ assert(output_data != nullptr);
- if (t0->element_type() == DataType::U8 || t0->element_type() == DataType::S8 ||
- t0->element_type() == DataType::S16)
+ // Write each input into the output, one copy_size block per outer index.
+ for (int i = 0; i < values_count; ++i)
{
- LUCI_INTERPRETER_CHECK(output()->zero_point() == t0->zero_point());
- LUCI_INTERPRETER_CHECK(output()->scale() == t0->scale());
- // Guarantee input/output quantization params match as we do not support
- // packing quantized tensors.
- for (int i = 0; i < params().values_count; i++)
+ // Resolve the i-th input tensor and its data through the runtime graph.
+ const auto input_index = cur_op->inputs()->operator[](i);
+ assert(input_index != -1);
+ const auto input = runtime_graph->getCircleTensorByIndex(input_index);
+
+ auto input_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(input));
+ assert(input_data != nullptr);
+ for (int k = 0; k < outer_size; ++k)
{
- LUCI_INTERPRETER_CHECK(_inputs[i]->zero_point() == t0->zero_point());
- LUCI_INTERPRETER_CHECK(_inputs[i]->scale() == t0->scale());
+ // Source is block k of this input; destination is slot i inside outer block k.
+ const T *input_ptr = input_data + copy_size * k;
+ int loc = k * values_count * copy_size + i * copy_size;
+ T *output_ptr = output_data + loc;
+ for (int j = 0; j < copy_size; ++j)
+ output_ptr[j] = input_ptr[j];
}
}
- // TODO: enable it only if kernel with dynamic shapes
- output()->resize(output_shape);
}
-void Pack::execute() const
+} // namespace
+
+/**
+ * Configure step of the Pack kernel. The body is empty and both arguments are ignored,
+ * so no shape, type or quantization checks are performed here.
+ * NOTE(review): confirm the inputs are validated elsewhere before execution.
+ */
+void configure_kernel_CirclePack(const circle::Operator *, BaseRuntimeGraph *)
+{
+ // Do nothing
+}
+
+/**
+ * Executes the Pack kernel: resolves the first input and the output tensor of cur_op,
+ * fetches the output buffer and dispatches to packImpl<T> on the output element type.
+ *
+ * Handled types are FLOAT32, S8, U8, S32 and S64. The FLOAT32 case is compiled out when
+ * DIS_FLOAT is defined; the S8 and U8 cases are compiled out when DIS_QUANT is defined.
+ * Any other type reaches the default branch, which asserts.
+ *
+ * @param cur_op operator whose inputs, outputs and options describe this Pack
+ * @param runtime_graph graph used to look up tensors and their data by index
+ */
+void execute_kernel_CirclePack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- switch (_inputs[0]->element_type())
+ // Only input 0 is resolved here and is passed on for its shape; packImpl looks up
+ // every input by index itself.
+ const auto input_index = cur_op->inputs()->operator[](0);
+ const auto output_index = cur_op->outputs()->operator[](0);
+ assert(output_index != -1);
+ assert(input_index != -1);
+ const auto input = runtime_graph->getCircleTensorByIndex(input_index);
+ const auto output = runtime_graph->getCircleTensorByIndex(output_index);
+
+ auto output_data = runtime_graph->getDataByTensor(output);
+ assert(output_data != nullptr);
+
+ // Dispatch on the element type of the output tensor.
+ switch (Tensor::element_type(output))
{
+#ifndef DIS_FLOAT
case DataType::FLOAT32:
- evalGeneric<float>();
- break;
- case DataType::U8:
- evalGeneric<uint8_t>();
+ packImpl<float>(input, output, cur_op, runtime_graph, output_data);
break;
+#endif // DIS_FLOAT
+#ifndef DIS_QUANT
case DataType::S8:
- evalGeneric<int8_t>();
+ packImpl<int8_t>(input, output, cur_op, runtime_graph, output_data);
break;
- case DataType::S16:
- evalGeneric<int16_t>();
+ case DataType::U8:
+ packImpl<uint8_t>(input, output, cur_op, runtime_graph, output_data);
break;
+#endif // DIS_QUANT
case DataType::S32:
- evalGeneric<int32_t>();
+ packImpl<int32_t>(input, output, cur_op, runtime_graph, output_data);
break;
case DataType::S64:
- evalGeneric<int64_t>();
+ packImpl<int64_t>(input, output, cur_op, runtime_graph, output_data);
break;
default:
- assert(false && "Unsupported type.");
- }
-}
-
-template <typename T> void Pack::evalGeneric() const
-{
- const Tensor *t0 = _inputs[0];
- const int dimension_size = t0->shape().num_dims() + 1;
- int axis = params().axis;
- if (axis < 0)
- {
- axis += dimension_size;
+ assert(false && "Unsupported types");
}
-
- VectorOfTensors<T, true> inputs(_inputs);
- tflite::PackParams params{};
- params.axis = axis;
- params.inputs_count = _inputs.size();
- tflite::reference_ops::Pack<T>(params, inputs.shapes(), inputs.data(), getTensorShape(output()),
- getTensorData<T>(output()));
}
-} // namespace kernels
} // namespace luci_interpreter