Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / ArgMinMaxLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ArgMinMaxLayer.h"
18
19 #include "OperationUtils.h"
20
21 #include <cker/operation/ArgMinMax.h>
22 #include <assert.h>
23
24 namespace onert
25 {
26 namespace backend
27 {
28 namespace cpu
29 {
30 namespace ops
31 {
32 namespace
33 {
34 template <typename T> std::function<bool(T, T)> GetComparefunction(bool is_arg_max)
35 {
36   if (is_arg_max)
37   {
38     return std::greater<T>();
39   }
40   else
41   {
42     return std::less<T>();
43   }
44 }
45 }
46
47 void ArgMinMaxLayer::configure(const IPortableTensor *input, IPortableTensor *output, int32_t axis,
48                                bool is_arg_max)
49 {
50   _input = input;
51   _output = output;
52   if (axis < 0)
53   {
54     axis += input->num_dimensions();
55   }
56   _axis = axis;
57   _is_arg_max = is_arg_max;
58 }
59
60 void ArgMinMaxLayer::run()
61 {
62 #define TF_LITE_ARG_MIN_MAX(input_type, axis_type, output_type)                                 \
63   ArgMinMax(getTensorShape(_input), reinterpret_cast<const input_type *>(_input->buffer()),     \
64             getTensorShape(_output), reinterpret_cast<output_type *>(_output->buffer()), _axis, \
65             GetComparefunction<input_type>(_is_arg_max));
66   if (_output->data_type() == ir::DataType::INT32)
67   {
68     switch (_input->data_type())
69     {
70       case ir::DataType::FLOAT32:
71         TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
72         break;
73       case ir::DataType::QUANT_UINT8_ASYMM:
74       case ir::DataType::UINT8:
75         TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
76         break;
77       case ir::DataType::INT32:
78         TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
79         break;
80       default:
81         throw std::runtime_error("ArgMinMax: unsupported data type");
82     }
83   }
84   else if (_output->data_type() == ir::DataType::INT64)
85   {
86     switch (_input->data_type())
87     {
88       case ir::DataType::FLOAT32:
89         TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
90         break;
91       case ir::DataType::QUANT_UINT8_ASYMM:
92       case ir::DataType::UINT8:
93         TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
94         break;
95       case ir::DataType::INT32:
96         TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
97         break;
98       default:
99         throw std::runtime_error("ArgMinMax: unsupported data type");
100     }
101   }
102   else
103   {
104     throw std::runtime_error("ArgMinMax: unsupported data type");
105   }
106
107 #undef TF_LITE_ARG_MIN_MAX
108 }
109
110 } // namespace ops
111 } // namespace cpu
112 } // namespace backend
113 } // namespace onert