// Look up the ArgMax node's output (ofm) and input (ifm) tensor allocations.
auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
// NOTE(review): std::getenv returns nullptr when USE_SIMPLE_ARGMINMAX is not
// set — presumably from_env<bool> maps nullptr to false; confirm it does not
// dereference the pointer.
- if (::internal::arm_compute::isGpuMode())
+ if (from_env<bool>(std::getenv("USE_SIMPLE_ARGMINMAX")))
{
- auto fn = nnfw::cpp14::make_unique<::arm_compute::CLArgMinMax>();
- fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc), param.axis,
- ::arm_compute::ArgOperation::MAX);
+ // USE CPU VERSION OF ARGMAX
+ auto fn = nnfw::cpp14::make_unique<SimpleArgMinMax>();
+
+ fn->configure(ifm_alloc, ofm_alloc, param.axis, ::arm_compute::ArgOperation::MAX);
+
builder.append("ArgMax", std::move(fn));
}
// Without the env override, keep the original behavior: CL kernel on GPU,
// otherwise unsupported.
else
- throw std::runtime_error("Not supported, yet");
+ {
+
+ if (::internal::arm_compute::isGpuMode())
+ {
+ auto fn = nnfw::cpp14::make_unique<::arm_compute::CLArgMinMax>();
+
+ fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc), param.axis,
+ ::arm_compute::ArgOperation::MAX);
+
+ builder.append("ArgMax", std::move(fn));
+ }
+ else
+ throw std::runtime_error("Not supported, yet");
+ }
};
_builder.addStage(stage);
#include <utility> // std::move (used by SimpleArgMinMax::configure)

#include <arm_compute/runtime/CL/CLScheduler.h>
/// Records the operands and reduction parameters for a later run() call.
///
/// @param[in]  input  Tensor to reduce (not owned; must outlive this object).
/// @param[out] output Tensor receiving the arg-min/arg-max indices.
/// @param[in]  axis   Dimension(s) along which to search.
/// @param[in]  op     Whether to compute ArgMin or ArgMax.
void SimpleArgMinMax::configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *output,
                                std::vector<uint32_t> axis, ::arm_compute::ArgOperation op)
{
  _input = input;
  _output = output;
  // `axis` is taken by value, so move it into the member instead of copying;
  // callers are unaffected by this change.
  _axis = std::move(axis);
  // Derive the rank from the tensor metadata rather than trusting a
  // caller-supplied value (the old `rank` parameter was dropped for this).
  _input_rank = input->info()->num_dimensions();
  _op_type = op;
}
inline const ::arm_compute::TensorShape
}
// Scans the reduced dimension for the output element at (b, d, h, w) and
// tracks where the running min/max occurs.
// NOTE(review): this diff hunk elides context — the declarations of `id`,
// `value`, `tval`, `min_max_id`, the loop bounds `start_w`/`stop_w`, and the
// function's return statement are not visible here.
template <typename T>
-inline T
-getArgMinMaxEle(const ::arm_compute::ITensor *input, const ::arm_compute::TensorShape &input_shape,
- const ::arm_compute::TensorShape &output_shape, const size_t b, const size_t d,
- const size_t h, const size_t w, const int axis, bool is_min, bool is_max)
+inline T getArgMinMaxEle(const ::arm_compute::ITensor *input,
+ const ::arm_compute::TensorShape &input_shape,
+ const ::arm_compute::TensorShape &output_shape, const size_t b,
+ const size_t d, const size_t h, const size_t w, const int axis,
+ const ::arm_compute::ArgOperation op_type)
{
// If output[dimention] == 1, will check all values of that dimension because of reducing
// dimension.
// Walk the candidate positions along the width axis (bounds come from the
// elided code above).
for (size_t in_w = start_w; in_w <= stop_w; ++in_w)
{
id.set(0, in_w);
- if (is_min)
+ if (op_type == ::arm_compute::ArgOperation::MIN)
{
value = std::min<T>(value, *reinterpret_cast<T *>(input->ptr_to_element(id)));
}
- else if (is_max)
+ else if (op_type == ::arm_compute::ArgOperation::MAX)
{
value = std::max<T>(value, *reinterpret_cast<T *>(input->ptr_to_element(id)));
}
// NOTE(review): this unsupported-op check executes once per element; hoisting
// it out of the loop (or validating op_type in configure()) would avoid the
// per-element branch and throw earlier.
+ else
+ throw std::runtime_error("This Arg operation is not supported, yet");
+
// Remember the coordinates where the running value last changed — presumably
// `tval` holds the previous `value` (set in elided code); confirm.
if (tval != value)
{
min_max_id = id;
}
// Fills `output` with the arg-min/arg-max index for every output element by
// delegating each element to getArgMinMaxEle<T>.
// NOTE(review): this hunk elides the nested loops that define
// `out_d`/`out_h`/`out_w`; only the outermost batch loop is visible.
template <typename T>
-inline void getArgMinMax(const ::arm_compute::ITensor *input,
- const ::arm_compute::TensorShape &input_shape,
- const ::arm_compute::TensorShape &output_shape,
- ::arm_compute::ITensor *output, const int axis, bool is_min, bool is_max)
+inline void
+getArgMinMax(const ::arm_compute::ITensor *input, const ::arm_compute::TensorShape &input_shape,
+ const ::arm_compute::TensorShape &output_shape, ::arm_compute::ITensor *output,
+ const int axis, const ::arm_compute::ArgOperation op_type)
{
::arm_compute::Coordinates id;
// Shape index 3 is the batch dimension in this layout.
for (size_t out_b = 0; out_b < output_shape[3]; ++out_b)
{
id.set(0, out_w);
// NOTE(review): the result is stored through an int* regardless of T —
// verify the output tensor's data type is S32 for all supported inputs.
*reinterpret_cast<int *>(output->ptr_to_element(id)) = getArgMinMaxEle<T>(
- input, input_shape, output_shape, out_b, out_d, out_h, out_w, axis, is_min, is_max);
+ input, input_shape, output_shape, out_b, out_d, out_h, out_w, axis, op_type);
}
}
}
// Fragment of SimpleArgMinMax::run(): dispatch the typed reference kernel on
// the input tensor's element type. `input_shape`, `output_shape` and
// `axis_val` are computed in elided context above this hunk.
switch (_input->info()->data_type())
{
case ::arm_compute::DataType::QASYMM8:
- getArgMinMax<uint8_t>(_input, input_shape, output_shape, _output, axis_val, _is_min, _is_max);
+ getArgMinMax<uint8_t>(_input, input_shape, output_shape, _output, axis_val, _op_type);
break;
case ::arm_compute::DataType::S32:
- getArgMinMax<int32_t>(_input, input_shape, output_shape, _output, axis_val, _is_min, _is_max);
+ getArgMinMax<int32_t>(_input, input_shape, output_shape, _output, axis_val, _op_type);
break;
case ::arm_compute::DataType::F32:
- getArgMinMax<float>(_input, input_shape, output_shape, _output, axis_val, _is_min, _is_max);
+ getArgMinMax<float>(_input, input_shape, output_shape, _output, axis_val, _op_type);
break;
// Any other element type is a hard error via the library's error macro.
default:
ARM_COMPUTE_ERROR("DataType not supported");
#define __SIMPLE_ARG_MIN_MAX_H__
#include "internal/arm_compute.h"
+#include "arm_compute/core/TypesEx.h"
// CPU reference implementation of ArgMin/ArgMax exposed through the
// IFunction interface.
// NOTE(review): this hunk elides parts of the class body — the doc-comment
// opener for configure(), the access specifier before the data members, and
// (presumably) the `_input` member declaration.
class SimpleArgMinMax : public ::arm_compute::IFunction
{
public:
// NOTE(review): after this change the constructor no longer initializes the
// operation member — `_op_type` is indeterminate until configure() runs.
// Consider adding a default (e.g. ArgOperation::MAX) to the init list.
- SimpleArgMinMax(void)
- : _input(nullptr), _output(nullptr), _axis(), _input_rank(0), _is_min(false), _is_max(false)
+ SimpleArgMinMax(void) : _input(nullptr), _output(nullptr), _axis(), _input_rank(0)
{
// DO NOTHING
}
* @param[in] input First tensor input.
* @param[out] output Output tensor.
* @param[in] axis Dimension along which to find Min or Max Index.
- * @param[in] input_rank Rank of input tensor.
- * @param[in] is_min True for ArgMin.
- * @param[in] is_max True for ArgMax.
*/
-
// NOTE(review): the declaration names this parameter `_op_type` while the
// definition calls it `op`; the leading underscore reads like a member.
// Renaming the declaration parameter to `op` would be clearer (no ABI impact).
void configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *output,
- std::vector<uint32_t> axis, int rank, bool is_min, bool is_max);
+ std::vector<uint32_t> axis, ::arm_compute::ArgOperation _op_type);
void run() override;
::arm_compute::ITensor *_output;
std::vector<uint32_t> _axis;
int _input_rank;
- bool _is_min;
- bool _is_max;
+ ::arm_compute::ArgOperation _op_type;
};
#endif /*__SIMPLE_ARG_MIN_MAX_H__ */