Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / ArgMax.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/ArgMax.h"
18 #include "kernels/Utils.h"
19 #include "PALArgMax.h"
20
21 namespace luci_interpreter
22 {
23 namespace kernels
24 {
25
26 ArgMax::ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams &params)
27   : KernelWithParams<ArgMaxParams>({input, axis}, {output}, params)
28 {
29 }
30
31 void ArgMax::configure()
32 {
33   assert(axis()->element_type() == DataType::S32 || axis()->element_type() == DataType::S64);
34   assert(input()->shape().num_dims() >= 1);
35   const Shape &input_shape = input()->shape();
36   const int num_dims = input_shape.num_dims();
37   Shape output_shape(num_dims - 1);
38
39   // If axis value is negative, then update by adding input_shape's num_dims.
40   // If updated value also negative, then assert.
41   assert(axis()->shape().num_elements() == 1);
42   int axis_value = getTensorData<int32_t>(axis())[0];
43   if (axis_value < 0)
44     axis_value = axis_value + num_dims;
45   assert(axis_value >= 0);
46
47   int j = 0;
48   for (int i = 0; i < num_dims; i++)
49   {
50     if (i == axis_value)
51       continue;
52     output_shape.dim(j++) = input_shape.dim(i);
53   }
54
55   assert(output()->element_type() == _params.output_type);
56
57   output()->resize(output_shape);
58 }
59
60 void ArgMax::execute() const
61 {
62
63 #define TF_LITE_ARG_MAX(data_type, axis_type, output_type)                                    \
64   luci_interpreter_pal::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
65                                   getTensorData<axis_type>(axis()), getTensorShape(output()), \
66                                   getTensorData<output_type>(output()), std::greater<data_type>())
67   if (axis()->element_type() == DataType::S32)
68   {
69     switch (_params.output_type)
70     {
71       case DataType::S32:
72         switch (input()->element_type())
73         {
74           case DataType::FLOAT32:
75             TF_LITE_ARG_MAX(float, int32_t, int32_t);
76             break;
77           case DataType::U8:
78             TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
79             break;
80           default:
81             throw std::runtime_error("Unsupported input type.");
82         }
83         break;
84       case DataType::S64:
85         switch (input()->element_type())
86         {
87           case DataType::FLOAT32:
88             TF_LITE_ARG_MAX(float, int32_t, int64_t);
89             break;
90           case DataType::U8:
91             TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
92             break;
93           default:
94             throw std::runtime_error("Unsupported input type.");
95         }
96         break;
97       default:
98         throw std::runtime_error("Unsupported output type.");
99     }
100   }
101   else
102   {
103     switch (_params.output_type)
104     {
105       case DataType::S32:
106         switch (input()->element_type())
107         {
108           case DataType::FLOAT32:
109             TF_LITE_ARG_MAX(float, int64_t, int32_t);
110             break;
111           case DataType::U8:
112             TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
113             break;
114           default:
115             throw std::runtime_error("Unsupported input type.");
116         }
117         break;
118       case DataType::S64:
119         switch (input()->element_type())
120         {
121           case DataType::FLOAT32:
122             TF_LITE_ARG_MAX(float, int64_t, int64_t);
123             break;
124           case DataType::U8:
125             TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
126             break;
127           default:
128             throw std::runtime_error("Unsupported input type.");
129         }
130         break;
131       default:
132         throw std::runtime_error("Unsupported output type.");
133     }
134   }
135 #undef TF_LITE_ARG_MAX
136 }
137
138 } // namespace kernels
139 } // namespace luci_interpreter