[tflite_loader] Implemented Mean node conversion (#7267)
authorIvan Vagin/AI Tools Lab /SRR/Engineer/삼성전자 <ivan.vagin@samsung.com>
Mon, 9 Sep 2019 01:45:41 +0000 (04:45 +0300)
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Mon, 9 Sep 2019 01:45:41 +0000 (10:45 +0900)
* [tflite_loader] Implemented Mean node conversion

* implemented Mean node conversion
* changed MeanNode keep_dims param type to bool

Signed-off-by: Ivan Vagin <ivan.vagin@samsung.com>
* Fixed formatting

Signed-off-by: Ivan Vagin <ivan.vagin@samsung.com>
runtimes/neurun/backend/acl_cl/KernelGenerator.cc
runtimes/neurun/backend/acl_neon/KernelGenerator.cc
runtimes/neurun/core/include/model/operation/MeanNode.h
runtimes/neurun/frontend/nnapi/wrapper/OperationFactory.cc
runtimes/neurun/frontend/tflite/loader.cc
runtimes/neurun/frontend/tflite/loader.h

index 5d3c9e1..cdc8948 100644 (file)
@@ -1879,8 +1879,8 @@ void KernelGenerator::visit(const model::operation::MeanNode &node)
   const auto ifm_index{node.getInputs().at(model::operation::MeanNode::Input::INPUT)};
 
   const auto axis_index{node.param().axis_index};
-  const auto keep_dims_index{node.param().keep_dims_index};
-  (void)keep_dims_index;
+  const auto keep_dims{node.param().keep_dims};
+  (void)keep_dims;
 
   const auto ifm_shape = _ctx.at(ifm_index).shape();
 
index f4eb91e..2bc6db2 100644 (file)
@@ -325,7 +325,7 @@ void KernelGenerator::visit(const model::operation::MeanNode &node)
   const auto ifm_index{node.getInputs().at(model::operation::MeanNode::Input::INPUT)};
 
   const auto axis_index{node.param().axis_index};
-  const auto keep_dims_index{node.param().keep_dims_index};
+  const auto keep_dims{node.param().keep_dims};
 
   const auto ifm_shape = _ctx.at(ifm_index).shape();
 
@@ -383,8 +383,6 @@ void KernelGenerator::visit(const model::operation::MeanNode &node)
     fixed_axis.set(fixed_axis.num_dimensions(), a);
   }
 
-  bool keep_dims = _ctx.at(keep_dims_index).asScalar<int32_t>() != 0;
-
   std::unique_ptr<::arm_compute::IFunction> fn;
 
   // NOTE NEReduceMean has a bug that does not support NHWC layout
index b6c9d95..9d14254 100644 (file)
@@ -37,7 +37,7 @@ public:
   struct Param
   {
     OperandIndex axis_index;
-    OperandIndex keep_dims_index;
+    bool keep_dims;
   };
 
 public:
index 89556b2..cd9d869 100644 (file)
@@ -1390,7 +1390,7 @@ OperationFactory::OperationFactory()
   };
 
   _map[ANEURALNETWORKS_MEAN] = [](const OperationFactory::Param &init_param,
-                                  neurun::model::Operands &) {
+                                  neurun::model::Operands &operands) {
     assert(init_param.input_count == 3 && init_param.output_count == 1);
 
     OperandIndexSequence outputs{init_param.outputs[0]};
@@ -1404,7 +1404,7 @@ OperationFactory::OperationFactory()
 
     operation::MeanNode::Param param;
     param.axis_index = OperandIndex{init_param.inputs[1]};
-    param.keep_dims_index = OperandIndex{init_param.inputs[2]};
+    param.keep_dims = operands.at(OperandIndex{init_param.inputs[2]}).asScalar<int32_t>() != 0;
 
     return new operation::MeanNode{inputs, outputs, param};
   };
index 58a2a30..cc1d522 100644 (file)
@@ -468,6 +468,24 @@ void Loader::loadTranspose(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
+void Loader::loadMean(const tflite::Operator *op)
+{
+  model::OperandIndexSequence inputs;
+  model::OperandIndexSequence outputs;
+
+  const auto input_index = (*op->inputs())[0];
+  inputs.append(model::OperandIndex(input_index));
+  const auto output_index = (*op->outputs())[0];
+  outputs.append(model::OperandIndex(output_index));
+
+  model::operation::MeanNode::Param param;
+  param.axis_index = model::OperandIndex((*op->inputs())[1]);
+  param.keep_dims = op->builtin_options_as_ReducerOptions()->keep_dims();
+
+  std::unique_ptr<model::Operation> new_op(new model::operation::MeanNode(inputs, outputs, param));
+  _graph.addOperation(std::move(new_op));
+}
+
 void Loader::loadCustom(const tflite::Operator *op)
 {
   model::OperandIndexSequence inputs;
@@ -565,6 +583,9 @@ void Loader::loadOperation(const tflite::Operator *op)
     case BuiltinOperator_TRANSPOSE:
       loadTranspose(op);
       return;
+    case BuiltinOperator_MEAN:
+      loadMean(op);
+      return;
     case BuiltinOperator_CUSTOM:
       loadCustom(op);
       return;
index f5fbefd..0ccd948 100644 (file)
@@ -93,6 +93,7 @@ private:
   void loadSquaredDifference(const tflite::Operator *op);
   void loadTanh(const tflite::Operator *op);
   void loadTranspose(const tflite::Operator *op);
+  void loadMean(const tflite::Operator *op);
 
   void loadCustom(const tflite::Operator *op);