Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / common / PALArgMinMax.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_ARG_MIN_MAX_H
19 #define LUCI_INTERPRETER_PAL_ARG_MIN_MAX_H
20
21 #include "Params.h"
22 #include "PALUtils.h"
23
24 namespace luci_interpreter_pal
25 {
26
27 template <typename T1, typename T2, typename T3, typename Cmp>
28 void ArgMinMax(const luci_interpreter::RuntimeShape &input1_shape, const T1 *input1_data,
29                const T3 *input2_data, const luci_interpreter::RuntimeShape &output_shape,
30                T2 *output_data, const Cmp &cmp)
31 {
32   int axis = input2_data[0];
33   if (axis < 0)
34   {
35     axis += input1_shape.dimensionsCount();
36   }
37   const int axis_size = input1_shape.dims(axis);
38
39   int outer_size = 1;
40   for (int i = 0; i < axis; ++i)
41   {
42     outer_size *= input1_shape.dims(i);
43   }
44
45   int inner_size = 1;
46   const int dims_count = input1_shape.dimensionsCount();
47   for (int i = axis + 1; i < dims_count; ++i)
48   {
49     inner_size *= input1_shape.dims(i);
50   }
51   for (int outer = 0; outer < outer_size; ++outer)
52   {
53     for (int inner = 0; inner < inner_size; ++inner)
54     {
55       auto min_max_value = input1_data[outer * axis_size * inner_size + inner];
56       T2 min_max_index = 0;
57       for (int i = 1; i < axis_size; ++i)
58       {
59         const auto &curr_value = input1_data[(outer * axis_size + i) * inner_size + inner];
60         if (cmp(curr_value, min_max_value))
61         {
62           min_max_value = curr_value;
63           min_max_index = static_cast<T2>(i);
64         }
65       }
66       output_data[outer * inner_size + inner] = min_max_index;
67     }
68   }
69 }
70
71 } // namespace luci_interpreter_pal
72
73 #endif // LUCI_INTERPRETER_PAL_ARG_MIN_MAX_H