2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ArgMinMaxLayer.h"
19 #include "OperationUtils.h"
21 #include <cker/operation/ArgMinMax.h>
/**
 * @brief Returns the element comparator handed to the ArgMinMax kernel.
 *
 * @tparam T element type of the input tensor.
 * @param is_arg_max true selects std::greater<T> (ArgMax); false selects std::less<T> (ArgMin).
 * @return a binary predicate cmp(a, b). It is strict, so equal elements never compare true.
 *         NOTE(review): how ties are resolved (first vs. last index) depends on how the cker
 *         kernel applies cmp — confirm in cker/operation/ArgMinMax.h.
 */
template <typename T> std::function<bool(T, T)> GetComparefunction(bool is_arg_max)
{
  if (is_arg_max)
  {
    return std::greater<T>();
  }
  return std::less<T>();
}
/**
 * @brief Records the operands and attributes that run() later uses.
 *
 * @param input  tensor to reduce; only read (taken as const IPortableTensor *).
 * @param output tensor that run() writes; run() accepts INT32 or INT64 element types.
 * @param axis   reduction axis. A negative value is shifted by input->num_dimensions()
 *               (presumably only when axis < 0 — the guard is not visible in this view).
 *
 * The is_arg_max flag is stored in _is_arg_max; run() uses it to choose between
 * std::greater and std::less via GetComparefunction.
 *
 * NOTE(review): the parameter list continues past the visible line, and the statements that
 * store input/output/axis are not visible here — confirm against the full file.
 */
void ArgMinMaxLayer::configure(const IPortableTensor *input, IPortableTensor *output, int32_t axis,
axis += input->num_dimensions();
_is_arg_max = is_arg_max;
60 void ArgMinMaxLayer::run()
62 #define TF_LITE_ARG_MIN_MAX(input_type, axis_type, output_type) \
63 ArgMinMax(getTensorShape(_input), reinterpret_cast<const input_type *>(_input->buffer()), \
64 getTensorShape(_output), reinterpret_cast<output_type *>(_output->buffer()), _axis, \
65 GetComparefunction<input_type>(_is_arg_max));
66 if (_output->data_type() == ir::DataType::INT32)
68 switch (_input->data_type())
70 case ir::DataType::FLOAT32:
71 TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
73 case ir::DataType::QUANT_UINT8_ASYMM:
74 case ir::DataType::UINT8:
75 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
77 case ir::DataType::INT32:
78 TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
81 throw std::runtime_error("ArgMinMax: unsupported data type");
84 else if (_output->data_type() == ir::DataType::INT64)
86 switch (_input->data_type())
88 case ir::DataType::FLOAT32:
89 TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
91 case ir::DataType::QUANT_UINT8_ASYMM:
92 case ir::DataType::UINT8:
93 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
95 case ir::DataType::INT32:
96 TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
99 throw std::runtime_error("ArgMinMax: unsupported data type");
104 throw std::runtime_error("ArgMinMax: unsupported data type");
107 #undef TF_LITE_ARG_MIN_MAX
112 } // namespace backend