#include "Builders.h"
#include "kernels/Utils.h"
-#include <tensorflow/lite/kernels/internal/reference/concatenation.h>
+#include "PALConcatenation.h"
namespace luci_interpreter
{
namespace
{
template <typename T>
-void evalGeneric(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph, bool)
+void evalGeneric(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
const auto output_index = cur_op->outputs()->operator[](0);
const auto input_sizes = cur_op->inputs()->size();
std::vector<const T *> all_input_data;
- std::vector<tflite::RuntimeShape> all_shape;
- std::vector<tflite::RuntimeShape *> all_shape_ptr;
-
- all_input_data.reserve(input_sizes);
- all_shape.reserve(input_sizes);
- all_shape_ptr.reserve(input_sizes);
+ std::vector<luci_interpreter::RuntimeShape> all_shape;
+ std::vector<luci_interpreter::RuntimeShape *> all_shape_ptr;
for (int32_t i = 0; i < input_sizes; ++i)
{
auto input_index = cur_op->inputs()->operator[](i);
const auto *tensor = runtime_graph->getCircleTensorByIndex(input_index);
- auto *data = reinterpret_cast<const T *>(runtime_graph->getDataByTensor(tensor));
+ const auto *tensor_data = runtime_graph->getDataByTensor(tensor);
+ if (tensor_data == nullptr)
+ tensor_data = runtime_graph->getConstDataByTensor(tensor);
+
+ auto *data = reinterpret_cast<const T *>(tensor_data);
+
+ auto runtime_shape = kernels::getTensorRuntimeShape(tensor, runtime_graph);
all_input_data.push_back(data);
- all_shape.push_back(kernels::getTensorShape(tensor));
+ all_shape.push_back(runtime_shape);
}
- for (tflite::RuntimeShape &shape : all_shape)
+ for (luci_interpreter::RuntimeShape &shape : all_shape)
{
all_shape_ptr.push_back(&shape);
}
auto *output_data = reinterpret_cast<T *>(runtime_graph->getDataByTensor(output));
- // kernels::VectorOfTensors<T, true> inputs(_inputs);
- tflite::ConcatenationParams params{};
+ luci_interpreter_pal::ConcatenationParams params{};
params.axis = axis;
- params.inputs_count = input_sizes;
- tflite::reference_ops::Concatenation(params, all_shape_ptr.data(), all_input_data.data(),
- kernels::getTensorShape(output), output_data);
+ params.inputs_count = all_shape.size();
+ luci_interpreter_pal::Concatenation(params, all_shape_ptr.data(), all_input_data.data(),
+ kernels::getTensorShape(output), output_data);
}
} // namespace
axis += Tensor::num_dims(t0);
LUCI_INTERPRETER_CHECK(axis >= 0 && axis < Tensor::num_dims(t0));
- int32_t sum_axis = Tensor::dim(t0, axis);
for (int i = 1; i < num_inputs; ++i)
{
input_index = cur_op->inputs()->operator[](i);
const auto *tensor = runtime_graph->getCircleTensorByIndex(input_index);
LUCI_INTERPRETER_CHECK(Tensor::element_type(tensor) == Tensor::element_type(t0));
LUCI_INTERPRETER_CHECK(Tensor::num_dims(tensor) == Tensor::num_dims(t0));
- for (int d = 0; d < Tensor::num_dims(t0); ++d)
- {
- if (d == axis)
- {
- sum_axis += Tensor::dim(tensor, axis);
- }
- else
- {
- LUCI_INTERPRETER_CHECK(Tensor::dim(tensor, d) == Tensor::dim(t0, d));
- }
- }
}
#ifndef DIS_QUANT
}
void execute_kernel_CircleConcatenation(const circle::Operator *cur_op,
- BaseRuntimeGraph *runtime_graph, bool is_inplace)
+ BaseRuntimeGraph *runtime_graph)
{
int num_inputs = cur_op->inputs()->size();
LUCI_INTERPRETER_CHECK(num_inputs > 0);
{
#ifndef DIS_FLOAT
case DataType::FLOAT32:
- evalGeneric<float>(cur_op, runtime_graph, is_inplace);
+ evalGeneric<float>(cur_op, runtime_graph);
break;
#endif // DIS_FLOAT
#ifndef DIS_QUANT
case DataType::S8:
- evalGeneric<int8_t>(cur_op, runtime_graph, is_inplace);
+ evalGeneric<int8_t>(cur_op, runtime_graph);
break;
+#endif // DIS_QUANT
case DataType::S32:
- evalGeneric<int32_t>(cur_op, runtime_graph, is_inplace);
+ evalGeneric<int32_t>(cur_op, runtime_graph);
break;
case DataType::S64:
- evalGeneric<int64_t>(cur_op, runtime_graph, is_inplace);
+ evalGeneric<int64_t>(cur_op, runtime_graph);
break;
-#endif
default:
assert(false && "Unsupported type.");
}