[Pure ACL] Store Network Inputs/Ouputs (#543)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 10 Apr 2018 23:36:03 +0000 (08:36 +0900)
committer김정현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh0822.kim@samsung.com>
Tue, 10 Apr 2018 23:36:03 +0000 (08:36 +0900)
This commit revises tflite::Model and arm_comput::Model to store the
inputs and outputs of a given network.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
tools/nnapi_bindings/bindings/pure_arm_compute/src/compilation.cc
tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.h
tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/arm_compute.h
tools/nnapi_bindings/bindings/pure_arm_compute/src/model.cc

index d046f8b..9a6ecd2 100644 (file)
@@ -18,5 +18,9 @@ ANeuralNetworksCompilation_create(ANeuralNetworksModel* model, ANeuralNetworksCo
 ResultCode
 ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation* compilation)
 {
+  // Set 'inputs' and 'outputs'
+  compilation->output().inputs = compilation->model().inputs;
+  compilation->output().outputs = compilation->model().outputs;
+
   return ANEURALNETWORKS_NO_ERROR;
 }
index 0e5aeea..04b7cc0 100644 (file)
@@ -343,6 +343,11 @@ public:
 private:
   operand::Set _operands;
   op::Sequence _operations;
+
+public:
+  // TODO Hide these fields
+  std::vector<operand::Index> inputs;
+  std::vector<operand::Index> outputs;
 };
 
 } // namespace tflite
index 8ce555c..6e2467d 100644 (file)
@@ -114,6 +114,11 @@ public:
 private:
   operand::Context _operands;
   op::Sequence _ops;
+
+public:
+  // TODO Hide these fields
+  std::vector<tflite::operand::Index> inputs;
+  std::vector<tflite::operand::Index> outputs;
 };
 
 } // namepsace arm_compute
index c40f01d..63dbe9b 100644 (file)
@@ -120,6 +120,25 @@ ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel* model,
                                               uint32_t inputCount, const uint32_t* inputs,
                                               uint32_t outputCount, const uint32_t* outputs)
 {
+  // NOTE ::internal::tflite::operand::Index uses int as its underlying type as various NNAPI
+  //      functions such as ANeuralNetworksModel_setOperandValue use int to represent operand index
+  //
+  //      ANeuralNetworksModel_identifyInputsAndOutputs, however, uses uint32_t to represent operand
+  //      index.
+  //
+  //      Below, static_cast<int>(...) is introduced to eliminate compiler warning.
+  for (uint32_t n = 0; n < inputCount; ++n)
+  {
+    const ::internal::tflite::operand::Index ind{static_cast<int>(inputs[n])};
+    model->deref().inputs.emplace_back(ind);
+  }
+
+  for (uint32_t n = 0; n < outputCount; ++n)
+  {
+    const ::internal::tflite::operand::Index ind{static_cast<int>(outputs[n])};
+    model->deref().outputs.emplace_back(ind);
+  }
+
   return ANEURALNETWORKS_NO_ERROR;
 }