Add calls to CPU version of SimpleArgMinMax op (#4029)
author    Shubham Gupta/SNAP /SRI-Bangalore/Engineer/삼성전자 <shub98.gupta@samsung.com>
Tue, 18 Dec 2018 07:07:43 +0000 (12:37 +0530)
committer 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 18 Dec 2018 07:07:43 +0000 (16:07 +0900)
This patch adds a call to the CPU version of the ArgMinMax op
and changes the SimpleArgMinMax files to use the ArgOperation enum.

Signed-off-by: shubham <shub98.gupta@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc
runtimes/pure_arm_compute/src/internal/layers/SimpleArgMinMax.cc
runtimes/pure_arm_compute/src/internal/layers/SimpleArgMinMax.h

index 5f7bfb9..e452990 100644 (file)
@@ -3766,15 +3766,30 @@ void Planner::visit(const ::internal::tflite::op::ArgMax::Node &node)
     auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
     auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
 
-    if (::internal::arm_compute::isGpuMode())
+    if (from_env<bool>(std::getenv("USE_SIMPLE_ARGMINMAX")))
     {
-      auto fn = nnfw::cpp14::make_unique<::arm_compute::CLArgMinMax>();
-      fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc), param.axis,
-                    ::arm_compute::ArgOperation::MAX);
+      // USE CPU VERSION OF ARGMAX
+      auto fn = nnfw::cpp14::make_unique<SimpleArgMinMax>();
+
+      fn->configure(ifm_alloc, ofm_alloc, param.axis, ::arm_compute::ArgOperation::MAX);
+
       builder.append("ArgMax", std::move(fn));
     }
     else
-      throw std::runtime_error("Not supported, yet");
+    {
+
+      if (::internal::arm_compute::isGpuMode())
+      {
+        auto fn = nnfw::cpp14::make_unique<::arm_compute::CLArgMinMax>();
+
+        fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc), param.axis,
+                      ::arm_compute::ArgOperation::MAX);
+
+        builder.append("ArgMax", std::move(fn));
+      }
+      else
+        throw std::runtime_error("Not supported, yet");
+    }
   };
 
   _builder.addStage(stage);
index 4df9d37..6d348e8 100644 (file)
 #include <arm_compute/runtime/CL/CLScheduler.h>
 
 void SimpleArgMinMax::configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *output,
-                                std::vector<uint32_t> axis, int rank, bool is_min, bool is_max)
+                                std::vector<uint32_t> axis, ::arm_compute::ArgOperation op)
 {
   _input = input;
   _output = output;
   _axis = axis;
-  _input_rank = rank;
-  _is_min = is_min;
-  _is_max = is_max;
+  _input_rank = input->info()->num_dimensions();
+  _op_type = op;
 }
 
 inline const ::arm_compute::TensorShape
@@ -49,10 +48,11 @@ inferOutputShape(const ::arm_compute::TensorShape &input_shape, const std::vecto
 }
 
 template <typename T>
-inline T
-getArgMinMaxEle(const ::arm_compute::ITensor *input, const ::arm_compute::TensorShape &input_shape,
-                const ::arm_compute::TensorShape &output_shape, const size_t b, const size_t d,
-                const size_t h, const size_t w, const int axis, bool is_min, bool is_max)
+inline T getArgMinMaxEle(const ::arm_compute::ITensor *input,
+                         const ::arm_compute::TensorShape &input_shape,
+                         const ::arm_compute::TensorShape &output_shape, const size_t b,
+                         const size_t d, const size_t h, const size_t w, const int axis,
+                         const ::arm_compute::ArgOperation op_type)
 {
   // If output[dimension] == 1, will check all values of that dimension because of reducing
   // dimension.
@@ -84,14 +84,17 @@ getArgMinMaxEle(const ::arm_compute::ITensor *input, const ::arm_compute::Tensor
         for (size_t in_w = start_w; in_w <= stop_w; ++in_w)
         {
           id.set(0, in_w);
-          if (is_min)
+          if (op_type == ::arm_compute::ArgOperation::MIN)
           {
             value = std::min<T>(value, *reinterpret_cast<T *>(input->ptr_to_element(id)));
           }
-          else if (is_max)
+          else if (op_type == ::arm_compute::ArgOperation::MAX)
           {
             value = std::max<T>(value, *reinterpret_cast<T *>(input->ptr_to_element(id)));
           }
+          else
+            throw std::runtime_error("This Arg operation is not supported, yet");
+
           if (tval != value)
           {
             min_max_id = id;
@@ -106,10 +109,10 @@ getArgMinMaxEle(const ::arm_compute::ITensor *input, const ::arm_compute::Tensor
 }
 
 template <typename T>
-inline void getArgMinMax(const ::arm_compute::ITensor *input,
-                         const ::arm_compute::TensorShape &input_shape,
-                         const ::arm_compute::TensorShape &output_shape,
-                         ::arm_compute::ITensor *output, const int axis, bool is_min, bool is_max)
+inline void
+getArgMinMax(const ::arm_compute::ITensor *input, const ::arm_compute::TensorShape &input_shape,
+             const ::arm_compute::TensorShape &output_shape, ::arm_compute::ITensor *output,
+             const int axis, const ::arm_compute::ArgOperation op_type)
 {
   ::arm_compute::Coordinates id;
   for (size_t out_b = 0; out_b < output_shape[3]; ++out_b)
@@ -125,7 +128,7 @@ inline void getArgMinMax(const ::arm_compute::ITensor *input,
         {
           id.set(0, out_w);
           *reinterpret_cast<int *>(output->ptr_to_element(id)) = getArgMinMaxEle<T>(
-              input, input_shape, output_shape, out_b, out_d, out_h, out_w, axis, is_min, is_max);
+              input, input_shape, output_shape, out_b, out_d, out_h, out_w, axis, op_type);
         }
       }
     }
@@ -153,13 +156,13 @@ void SimpleArgMinMax::run()
   switch (_input->info()->data_type())
   {
     case ::arm_compute::DataType::QASYMM8:
-      getArgMinMax<uint8_t>(_input, input_shape, output_shape, _output, axis_val, _is_min, _is_max);
+      getArgMinMax<uint8_t>(_input, input_shape, output_shape, _output, axis_val, _op_type);
       break;
     case ::arm_compute::DataType::S32:
-      getArgMinMax<int32_t>(_input, input_shape, output_shape, _output, axis_val, _is_min, _is_max);
+      getArgMinMax<int32_t>(_input, input_shape, output_shape, _output, axis_val, _op_type);
       break;
     case ::arm_compute::DataType::F32:
-      getArgMinMax<float>(_input, input_shape, output_shape, _output, axis_val, _is_min, _is_max);
+      getArgMinMax<float>(_input, input_shape, output_shape, _output, axis_val, _op_type);
       break;
     default:
       ARM_COMPUTE_ERROR("DataType not supported");
index 91b46c2..b90e745 100644 (file)
 #define __SIMPLE_ARG_MIN_MAX_H__
 
 #include "internal/arm_compute.h"
+#include "arm_compute/core/TypesEx.h"
 
 class SimpleArgMinMax : public ::arm_compute::IFunction
 {
 public:
-  SimpleArgMinMax(void)
-      : _input(nullptr), _output(nullptr), _axis(), _input_rank(0), _is_min(false), _is_max(false)
+  SimpleArgMinMax(void) : _input(nullptr), _output(nullptr), _axis(), _input_rank(0)
   {
     // DO NOTHING
   }
@@ -34,13 +34,9 @@ public:
    * @param[in]  input       First tensor input.
    * @param[out] output      Output tensor.
    * @param[in]  axis        Dimension along which to find Min or Max Index.
-   * @param[in]  input_rank  Rank of input tensor.
-   * @param[in]  is_min      True for ArgMin.
-   * @param[in]  is_max      True for ArgMax.
    */
-
   void configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *output,
-                 std::vector<uint32_t> axis, int rank, bool is_min, bool is_max);
+                 std::vector<uint32_t> axis, ::arm_compute::ArgOperation _op_type);
 
   void run() override;
 
@@ -49,8 +45,7 @@ private:
   ::arm_compute::ITensor *_output;
   std::vector<uint32_t> _axis;
   int _input_rank;
-  bool _is_min;
-  bool _is_max;
+  ::arm_compute::ArgOperation _op_type;
 };
 
 #endif /*__SIMPLE_ARG_MIN_MAX_H__ */