* limitations under the License.
*/
-#include "kernels/ArgMax.h"
+#include "Builders.h"
#include "kernels/Utils.h"
-#include "PALArgMax.h"
+#include "TISOKernel.h"
+
+#include "PALArgMinMax.h"
namespace luci_interpreter
{
-namespace kernels
-{
-ArgMax::ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams ¶ms)
- : KernelWithParams<ArgMaxParams>({input, axis}, {output}, params)
+void configure_kernel_CircleArgMax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
+ // dim tensor must be a scalar or has one element
+ LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input2()) == 0 or
+ Tensor::num_elements(kernel.input2()) == 1);
+ // value and output type must match
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::S32);
+
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32);
}
-void ArgMax::configure()
+void execute_kernel_CircleArgMax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
- assert(axis()->element_type() == DataType::S32 || axis()->element_type() == DataType::S64);
- assert(input()->shape().num_dims() >= 1);
- const Shape &input_shape = input()->shape();
- const int num_dims = input_shape.num_dims();
- Shape output_shape(num_dims - 1);
-
- // If axis value is negative, then update by adding input_shape's num_dims.
- // If updated value also negative, then assert.
- assert(axis()->shape().num_elements() == 1);
- int axis_value = getTensorData<int32_t>(axis())[0];
- if (axis_value < 0)
- axis_value = axis_value + num_dims;
- assert(axis_value >= 0);
-
- int j = 0;
- for (int i = 0; i < num_dims; i++)
- {
- if (i == axis_value)
- continue;
- output_shape.dim(j++) = input_shape.dim(i);
- }
+ kernels::TISOKernel kernel(cur_op, runtime_graph);
- assert(output()->element_type() == _params.output_type);
+ const circle::Tensor *input = kernel.input1();
+ const circle::Tensor *output = kernel.output();
- // TODO: enable it only if kernel with dynamic shapes
- output()->resize(output_shape);
-}
-
-void ArgMax::execute() const
-{
+ kernels::TISOData tiso_data = kernel.readData();
+ const auto input_data = tiso_data.input1_data;
+ const auto axis_data = tiso_data.input2_data;
+ auto output_data = tiso_data.output_data;
-#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
- luci_interpreter_pal::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
- getTensorData<axis_type>(axis()), getTensorShape(output()), \
- getTensorData<output_type>(output()), std::greater<data_type>())
- if (axis()->element_type() == DataType::S32)
- {
- switch (_params.output_type)
- {
- case DataType::S32:
- switch (input()->element_type())
- {
- case DataType::FLOAT32:
- TF_LITE_ARG_MAX(float, int32_t, int32_t);
- break;
- case DataType::U8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
- break;
- default:
- assert(false && "Unsupported input type.");
- }
- break;
- case DataType::S64:
- switch (input()->element_type())
- {
- case DataType::FLOAT32:
- TF_LITE_ARG_MAX(float, int32_t, int64_t);
- break;
- case DataType::U8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
- break;
- default:
- assert(false && "Unsupported input type.");
- }
- break;
- default:
- assert(false && "Unsupported output type.");
- }
- }
- else
+ switch (Tensor::element_type(input))
{
- switch (_params.output_type)
+#ifndef DIS_FLOAT
+ case DataType::FLOAT32:
{
- case DataType::S32:
- switch (input()->element_type())
- {
- case DataType::FLOAT32:
- TF_LITE_ARG_MAX(float, int64_t, int32_t);
- break;
- case DataType::U8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
- break;
- default:
- assert(false && "Unsupported input type.");
- }
- break;
- case DataType::S64:
- switch (input()->element_type())
- {
- case DataType::FLOAT32:
- TF_LITE_ARG_MAX(float, int64_t, int64_t);
- break;
- case DataType::U8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
- break;
- default:
- assert(false && "Unsupported input type.");
- }
- break;
- default:
- assert(false && "Unsupported output type.");
+ luci_interpreter_pal::ArgMinMax(
+ kernels::getTensorRuntimeShape(input, runtime_graph),
+ kernels::getTensorData<float>(input_data), kernels::getTensorData<int32_t>(axis_data),
+ kernels::getTensorRuntimeShape(output, runtime_graph),
+ kernels::getTensorData<int32_t>(output_data), std::greater<float>());
}
+ break;
+#endif // DIS_FLOAT
+ default:
+ assert(false && "Unsupported ArgMax input type");
}
-#undef TF_LITE_ARG_MAX
}
-} // namespace kernels
} // namespace luci_interpreter