2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ArgMinMaxLayer.h"
19 #include "OperationUtils.h"
21 #include <cker/operation/ArgMinMax.h>
/**
 * @brief Select the element comparator used by the ArgMinMax kernel.
 * @tparam T Element type being compared.
 * @param is_arg_max true picks std::greater<T> (arg-max), false picks std::less<T> (arg-min).
 * @return The chosen comparator wrapped in a std::function.
 */
template <typename T> std::function<bool(T, T)> GetComparefunction(bool is_arg_max)
{
  if (!is_arg_max)
    return std::less<T>();
  return std::greater<T>();
}
47 void ArgMinMaxLayer::configure(const IPortableTensor *input, IPortableTensor *output,
48 const IPortableTensor *axis, bool is_arg_max)
53 _is_arg_max = is_arg_max;
56 void ArgMinMaxLayer::run()
58 if (_axis->total_size() != sizeof(int32_t))
60 throw std::runtime_error("ArgMinMax: wrong shape of axis");
62 auto axis = *reinterpret_cast<const int32_t *>(_axis->buffer());
65 axis += _input->num_dimensions();
67 #define TF_LITE_ARG_MIN_MAX(input_type, axis_type, output_type) \
68 ArgMinMax(getTensorShape(_input), reinterpret_cast<const input_type *>(_input->buffer()), \
69 getTensorShape(_output), reinterpret_cast<output_type *>(_output->buffer()), axis, \
70 GetComparefunction<input_type>(_is_arg_max));
71 if (_output->data_type() == ir::DataType::INT32)
73 switch (_input->data_type())
75 case ir::DataType::FLOAT32:
76 TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
78 case ir::DataType::QUANT_UINT8_ASYMM:
79 case ir::DataType::UINT8:
80 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
82 case ir::DataType::QUANT_INT8_ASYMM:
83 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
85 case ir::DataType::INT32:
86 TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
89 throw std::runtime_error("ArgMinMax: unsupported data type");
92 else if (_output->data_type() == ir::DataType::INT64)
94 switch (_input->data_type())
96 case ir::DataType::FLOAT32:
97 TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
99 case ir::DataType::QUANT_UINT8_ASYMM:
100 case ir::DataType::UINT8:
101 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
103 case ir::DataType::QUANT_INT8_ASYMM:
104 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
106 case ir::DataType::INT32:
107 TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
110 throw std::runtime_error("ArgMinMax: unsupported data type");
115 throw std::runtime_error("ArgMinMax: unsupported data type");
118 #undef TF_LITE_ARG_MIN_MAX
123 } // namespace backend