2fd284c915a5c4813e8ffa09d6959554b6e381f9
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / ArgMinMaxLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ArgMinMaxLayer.h"
18
19 #include "OperationUtils.h"
20
21 #include <cker/operation/ArgMinMax.h>
22 #include <assert.h>
23
24 namespace onert
25 {
26 namespace backend
27 {
28 namespace cpu
29 {
30 namespace ops
31 {
32 namespace
33 {
34 template <typename T> std::function<bool(T, T)> GetComparefunction(bool is_arg_max)
35 {
36   if (is_arg_max)
37   {
38     return std::greater<T>();
39   }
40   else
41   {
42     return std::less<T>();
43   }
44 }
45 }
46
47 void ArgMinMaxLayer::configure(const IPortableTensor *input, IPortableTensor *output,
48                                const IPortableTensor *axis, bool is_arg_max)
49 {
50   _input = input;
51   _output = output;
52   _axis = axis;
53   _is_arg_max = is_arg_max;
54 }
55
56 void ArgMinMaxLayer::run()
57 {
58   if (_axis->total_size() != sizeof(int32_t))
59   {
60     throw std::runtime_error("ArgMinMax: wrong shape of axis");
61   }
62   auto axis = *reinterpret_cast<const int32_t *>(_axis->buffer());
63   if (axis < 0)
64   {
65     axis += _input->num_dimensions();
66   }
67 #define TF_LITE_ARG_MIN_MAX(input_type, axis_type, output_type)                                \
68   ArgMinMax(getTensorShape(_input), reinterpret_cast<const input_type *>(_input->buffer()),    \
69             getTensorShape(_output), reinterpret_cast<output_type *>(_output->buffer()), axis, \
70             GetComparefunction<input_type>(_is_arg_max));
71   if (_output->data_type() == ir::DataType::INT32)
72   {
73     switch (_input->data_type())
74     {
75       case ir::DataType::FLOAT32:
76         TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
77         break;
78       case ir::DataType::QUANT_UINT8_ASYMM:
79       case ir::DataType::UINT8:
80         TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
81         break;
82       case ir::DataType::INT32:
83         TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
84         break;
85       default:
86         throw std::runtime_error("ArgMinMax: unsupported data type");
87     }
88   }
89   else if (_output->data_type() == ir::DataType::INT64)
90   {
91     switch (_input->data_type())
92     {
93       case ir::DataType::FLOAT32:
94         TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
95         break;
96       case ir::DataType::QUANT_UINT8_ASYMM:
97       case ir::DataType::UINT8:
98         TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
99         break;
100       case ir::DataType::INT32:
101         TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
102         break;
103       default:
104         throw std::runtime_error("ArgMinMax: unsupported data type");
105     }
106   }
107   else
108   {
109     throw std::runtime_error("ArgMinMax: unsupported data type");
110   }
111
112 #undef TF_LITE_ARG_MIN_MAX
113 }
114
115 } // namespace ops
116 } // namespace cpu
117 } // namespace backend
118 } // namespace onert