5ac4bcb1d0f9259ffec9df32a511e642ea7da825
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / ArgMax.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/ArgMax.h"
19 #include "kernels/Utils.h"
20 #include "PALArgMax.h"
21
22 namespace luci_interpreter
23 {
24 namespace kernels
25 {
26
27 ArgMax::ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams &params)
28   : KernelWithParams<ArgMaxParams>({input, axis}, {output}, params)
29 {
30 }
31
32 void ArgMax::configure()
33 {
34   assert(axis()->element_type() == DataType::S32 || axis()->element_type() == DataType::S64);
35   assert(input()->shape().num_dims() >= 1);
36   const Shape &input_shape = input()->shape();
37   const int num_dims = input_shape.num_dims();
38   Shape output_shape(num_dims - 1);
39
40   // If axis value is negative, then update by adding input_shape's num_dims.
41   // If updated value also negative, then assert.
42   assert(axis()->shape().num_elements() == 1);
43   int axis_value = getTensorData<int32_t>(axis())[0];
44   if (axis_value < 0)
45     axis_value = axis_value + num_dims;
46   assert(axis_value >= 0);
47
48   int j = 0;
49   for (int i = 0; i < num_dims; i++)
50   {
51     if (i == axis_value)
52       continue;
53     output_shape.dim(j++) = input_shape.dim(i);
54   }
55
56   assert(output()->element_type() == _params.output_type);
57
58   // TODO: enable it only if kernel with dynamic shapes
59   output()->resize(output_shape);
60 }
61
62 void ArgMax::execute() const
63 {
64
65 #define TF_LITE_ARG_MAX(data_type, axis_type, output_type)                                    \
66   luci_interpreter_pal::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
67                                   getTensorData<axis_type>(axis()), getTensorShape(output()), \
68                                   getTensorData<output_type>(output()), std::greater<data_type>())
69   if (axis()->element_type() == DataType::S32)
70   {
71     switch (_params.output_type)
72     {
73       case DataType::S32:
74         switch (input()->element_type())
75         {
76           case DataType::FLOAT32:
77             TF_LITE_ARG_MAX(float, int32_t, int32_t);
78             break;
79           case DataType::U8:
80             TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
81             break;
82           default:
83             assert(false && "Unsupported input type.");
84         }
85         break;
86       case DataType::S64:
87         switch (input()->element_type())
88         {
89           case DataType::FLOAT32:
90             TF_LITE_ARG_MAX(float, int32_t, int64_t);
91             break;
92           case DataType::U8:
93             TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
94             break;
95           default:
96             assert(false && "Unsupported input type.");
97         }
98         break;
99       default:
100         assert(false && "Unsupported output type.");
101     }
102   }
103   else
104   {
105     switch (_params.output_type)
106     {
107       case DataType::S32:
108         switch (input()->element_type())
109         {
110           case DataType::FLOAT32:
111             TF_LITE_ARG_MAX(float, int64_t, int32_t);
112             break;
113           case DataType::U8:
114             TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
115             break;
116           default:
117             assert(false && "Unsupported input type.");
118         }
119         break;
120       case DataType::S64:
121         switch (input()->element_type())
122         {
123           case DataType::FLOAT32:
124             TF_LITE_ARG_MAX(float, int64_t, int64_t);
125             break;
126           case DataType::U8:
127             TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
128             break;
129           default:
130             assert(false && "Unsupported input type.");
131         }
132         break;
133       default:
134         assert(false && "Unsupported output type.");
135     }
136   }
137 #undef TF_LITE_ARG_MAX
138 }
139
140 } // namespace kernels
141 } // namespace luci_interpreter