Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / ArgMax.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "kernels/ArgMax.h"
18 #include "kernels/Utils.h"
19 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
20
21 namespace luci_interpreter
22 {
23 namespace kernels
24 {
25
26 ArgMax::ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams &params)
27     : KernelWithParams<ArgMaxParams>({input, axis}, {output}, params)
28 {
29 }
30
31 void ArgMax::configure()
32 {
33   assert(axis()->element_type() == DataType::S32 || axis()->element_type() == DataType::S64);
34   assert(input()->shape().num_dims() >= 1);
35   const Shape &input_shape = input()->shape();
36   const int num_dims = input_shape.num_dims();
37   Shape output_shape(num_dims - 1);
38
39   // If axis value is negative, then update by adding input_shape's num_dims.
40   // If updated value also negative, then assert.
41   assert(axis()->shape().num_elements() == 1);
42   int axis_value = getTensorData<int32_t>(axis())[0];
43   if (axis_value < 0)
44     axis_value = axis_value + num_dims;
45   assert(axis_value >= 0);
46
47   int j = 0;
48   for (int i = 0; i < num_dims; i++)
49   {
50     if (i == axis_value)
51       continue;
52     output_shape.dim(j++) = input_shape.dim(i);
53   }
54
55   assert(output()->element_type() == _params.output_type);
56
57   output()->resize(output_shape);
58 }
59
60 void ArgMax::execute() const
61 {
62
63 #define TF_LITE_ARG_MAX(data_type, axis_type, output_type)                                     \
64   tflite::optimized_ops::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
65                                    getTensorData<axis_type>(axis()), getTensorShape(output()), \
66                                    getTensorData<output_type>(output()),                       \
67                                    std::greater<data_type>())
68   if (axis()->element_type() == DataType::S32)
69   {
70     switch (_params.output_type)
71     {
72       case DataType::S32:
73         switch (input()->element_type())
74         {
75           case DataType::FLOAT32:
76             TF_LITE_ARG_MAX(float, int32_t, int32_t);
77             break;
78           case DataType::U8:
79             TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
80             break;
81           default:
82             throw std::runtime_error("Unsupported input type.");
83         }
84         break;
85       case DataType::S64:
86         switch (input()->element_type())
87         {
88           case DataType::FLOAT32:
89             TF_LITE_ARG_MAX(float, int32_t, int64_t);
90             break;
91           case DataType::U8:
92             TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
93             break;
94           default:
95             throw std::runtime_error("Unsupported input type.");
96         }
97         break;
98       default:
99         throw std::runtime_error("Unsupported output type.");
100     }
101   }
102   else
103   {
104     switch (_params.output_type)
105     {
106       case DataType::S32:
107         switch (input()->element_type())
108         {
109           case DataType::FLOAT32:
110             TF_LITE_ARG_MAX(float, int64_t, int32_t);
111             break;
112           case DataType::U8:
113             TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
114             break;
115           default:
116             throw std::runtime_error("Unsupported input type.");
117         }
118         break;
119       case DataType::S64:
120         switch (input()->element_type())
121         {
122           case DataType::FLOAT32:
123             TF_LITE_ARG_MAX(float, int64_t, int64_t);
124             break;
125           case DataType::U8:
126             TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
127             break;
128           default:
129             throw std::runtime_error("Unsupported input type.");
130         }
131         break;
132       default:
133         throw std::runtime_error("Unsupported output type.");
134     }
135   }
136 #undef TF_LITE_ARG_MAX
137 }
138
139 } // namespace kernels
140 } // namespace luci_interpreter