* limitations under the License.
*/
-#include "kernels/Gather.h"
+#include "Builders.h"
#include "kernels/Utils.h"
-#include "PALGather.h"
+#include "TISOKernel.h"
#include <cassert>
namespace luci_interpreter
{
-
-namespace kernels
+namespace
{
-Gather::Gather(const Tensor *params, const Tensor *indices, Tensor *output,
- const GatherParams &gparams)
- : KernelWithParams<GatherParams>({params, indices}, {output}, gparams)
+template <typename InputT, typename CoordsT = int32_t> // InputT: element type of input1 and the output; CoordsT: index type of input2.
+void gather(const circle::GatherOptions *options, kernels::TISOKernel *kernel) // Copies slices of input1, picked along the axis by the indices in input2, into the output.
{
-}
-
-void Gather::configure()
-{
- if (params()->element_type() == DataType::FLOAT32)
- {
- LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
- }
- else
- {
- assert(false && "Unsupported type.");
- }
-
- LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 ||
- indices()->element_type() == DataType::S64);
+ kernels::TISOData tiso_data = kernel->readData(); // provides the input1, input2 and output data pointers used below
- // refer tensorflow/lite/kernels/gather.cc
+ const InputT *input_data = kernels::getTensorData<InputT>(tiso_data.input1_data);
+ const CoordsT *coords_data = kernels::getTensorData<CoordsT>(tiso_data.input2_data);
+ InputT *output_data = kernels::getTensorData<InputT>(tiso_data.output_data);
- const Shape ¶ms_shape = params()->shape();
- const Shape &indices_shape = indices()->shape();
+ const circle::Tensor *input = kernel->input1();
+ const circle::Tensor *coords = kernel->input2();
- int axis = _params.axis;
+ const int input_dims_size = Tensor::num_dims(input);
+ int axis = options->axis(); // may be negative; the rank of the input is added just below in that case
if (axis < 0)
{
- axis += params_shape.num_dims();
+ axis += input_dims_size;
}
- LUCI_INTERPRETER_CHECK(0 <= axis && axis < params_shape.num_dims());
- int batch_dims = _params.batch_dims;
- // batch_dims should be in range: [-rank(indices), rank(indices)].
- // Negative batch_dims is added with rank of positions.
+ int batch_dims = options->batch_dims();
+ // batch_dims is expected to lie in [-rank(coords), rank(coords)]; this function does not check it.
+ // A negative batch_dims counts from the back: the rank of coords is added to it.
+ const int coords_dims_size = Tensor::num_dims(coords);
if (batch_dims < 0)
{
- batch_dims += indices_shape.num_dims();
+ batch_dims += coords_dims_size;
}
- LUCI_INTERPRETER_CHECK(batch_dims <= axis);
- LUCI_INTERPRETER_CHECK(0 <= batch_dims && batch_dims < params_shape.num_dims());
- LUCI_INTERPRETER_CHECK(batch_dims <= indices_shape.num_dims());
+
+ const int axis_size = Tensor::dim(input, axis); // extent of the input along the gather axis
+ // The four sizes below flatten the tensors so the copy loop can use plain linear offsets.
+ int batch_size = 1; // product of the first batch_dims input dimensions
for (int i = 0; i < batch_dims; ++i)
{
- LUCI_INTERPRETER_CHECK(params_shape.dim(i) == indices_shape.dim(i));
+ batch_size *= Tensor::dim(input, i);
}
-
- const int num_dimensions = params_shape.num_dims() + indices_shape.num_dims() - 1 - batch_dims;
-
- Shape output_shape(num_dimensions);
- int output_index = 0;
- for (int i = 0; i < axis; ++i)
+ int outer_size = 1; // product of the input dimensions in [batch_dims, axis)
+ for (int i = batch_dims; i < axis; ++i)
{
- output_shape.dim(output_index++) = params_shape.dim(i);
+ outer_size *= Tensor::dim(input, i);
}
- for (int i = batch_dims; i < indices_shape.num_dims(); ++i)
+ int inner_size = 1; // product of the input dimensions after axis; number of elements copied per index
+ for (int i = axis + 1; i < input_dims_size; ++i)
{
- output_shape.dim(output_index++) = indices_shape.dim(i);
+ inner_size *= Tensor::dim(input, i);
}
- for (int i = axis + 1; i < params_shape.num_dims(); ++i)
+ int coord_size = 1; // product of the coords dimensions from batch_dims onward; number of indices per batch
+ for (int i = batch_dims; i < coords_dims_size; ++i)
{
- output_shape.dim(output_index++) = params_shape.dim(i);
+ coord_size *= Tensor::dim(coords, i);
}
- // TODO: enable it only if kernel with dynamic shapes
- output()->resize(output_shape);
-}
-void Gather::execute() const
-{
- switch (params()->element_type())
+ for (int batch = 0; batch < batch_size; ++batch) // one memcpy of inner_size elements per (batch, outer, coord) triple
{
- case DataType::FLOAT32:
- evalFloat();
- break;
- default:
- assert(false && "Unsupported type.");
+ for (int outer = 0; outer < outer_size; ++outer)
+ {
+ for (int coord = 0; coord < coord_size; ++coord)
+ {
+ auto x = coords_data[coord]; // NOTE(review): x is never used below — confirm it can be removed
+ std::memcpy( // destination: output slot for (batch, outer, coord); source: input slice selected by the index
+ output_data + (((batch * outer_size) + outer) * coord_size + coord) * inner_size,
+ input_data +
+ (((batch * outer_size) + outer) * axis_size + coords_data[batch * coord_size + coord]) *
+ inner_size, // NOTE(review): the index read from coords_data is used as-is, with no bounds check against axis_size
+ sizeof(InputT) * inner_size); // NOTE(review): std::memcpy is declared in <cstring>; no such include is visible in this chunk — confirm
+ }
+ }
}
}
-void Gather::evalFloat() const
+} // namespace
+
+void configure_kernel_CircleGather(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) // Validates element types, axis and batch_dims of a Gather operator.
{
- assert(indices()->element_type() == DataType::S32 || indices()->element_type() == DataType::S64);
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
- const auto params_data = getTensorData<float>(params());
- auto output_data = getTensorData<float>(output());
+ const auto *options = cur_op->builtin_options_as_GatherOptions(); // NOTE(review): dereferenced below without a null check — confirm it cannot be null
- tflite::GatherParams tparams;
- tparams.axis = _params.axis;
- tparams.batch_dims = _params.batch_dims;
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32); // input2 (the indices) must be S32
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) == DataType::FLOAT32 or
+ Tensor::element_type(kernel.input1()) == DataType::S8 or
+ Tensor::element_type(kernel.input1()) == DataType::S32); // input1 must be FLOAT32, S8 or S32; NOTE(review): only the input types are checked in this function
- if (indices()->element_type() == DataType::S32)
+ int32_t axis = options->axis();
+ int32_t num_dims = Tensor::num_dims(kernel.input1());
+ if (axis < 0)
{
- const auto indices_data = getTensorData<int32_t>(indices());
+ axis += num_dims;
+ }
+ // After adding the rank to a negative value, axis must address an existing dimension of input1.
+ LUCI_INTERPRETER_CHECK(axis >= 0 and axis < num_dims);
- luci_interpreter_pal::Gather<float, int32_t>(tparams, getTensorShape(params()), params_data,
- getTensorShape(indices()), indices_data,
- getTensorShape(output()), output_data);
+ int32_t batch_dims = options->batch_dims();
+ int32_t coords_num_dims = Tensor::num_dims(kernel.input2());
+ // batch_dims should be in range: [-rank(coords), rank(coords)].
+ // A negative batch_dims counts from the back: the rank of coords is added to it.
+ if (batch_dims < 0)
+ {
+ batch_dims += coords_num_dims;
}
- else
+ LUCI_INTERPRETER_CHECK(batch_dims <= axis); // batch_dims may not exceed the gather axis
+ LUCI_INTERPRETER_CHECK(batch_dims >= 0 and batch_dims < num_dims);
+ LUCI_INTERPRETER_CHECK(batch_dims <= coords_num_dims);
+ for (int i = 0; i < batch_dims; ++i) // input1 and input2 must have equal extents on every batch dimension
{
- const auto indices_data = getTensorData<int64_t>(indices());
+ LUCI_INTERPRETER_CHECK(Tensor::dim(kernel.input1(), i) == Tensor::dim(kernel.input2(), i));
+ }
+}
+
+void execute_kernel_CircleGather(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) // Runs Gather, dispatching on the element type of input1.
+{
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
- luci_interpreter_pal::Gather<float, int64_t>(tparams, getTensorShape(params()), params_data,
- getTensorShape(indices()), indices_data,
- getTensorShape(output()), output_data);
+ const auto *options = cur_op->builtin_options_as_GatherOptions();
+ // Every branch instantiates gather with int32_t indices; only the data element type varies.
+ switch (Tensor::element_type(kernel.input1()))
+ {
+#ifndef DIS_FLOAT
+ case DataType::FLOAT32: // compiled out when DIS_FLOAT is defined
+ return gather<float, int32_t>(options, &kernel);
+#endif // DIS_FLOAT
+#ifndef DIS_QUANT
+ case DataType::S8: // compiled out when DIS_QUANT is defined
+ return gather<int8_t, int32_t>(options, &kernel);
+#endif // DIS_QUANT
+ case DataType::S32:
+ return gather<int32_t, int32_t>(options, &kernel);
+ default:
+ assert(false && "Unsupported type"); // NOTE(review): assert is a no-op under NDEBUG, so an unsupported type would then return without doing anything
}
}
-} // namespace kernels
} // namespace luci_interpreter