Add NNAPI delegation for EMBEDING_LOOKUP, RNN, SVDF
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 23 May 2018 21:33:59 +0000 (14:33 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 21:36:46 +0000 (14:36 -0700)
PiperOrigin-RevId: 197790679

tensorflow/contrib/lite/nnapi_delegate.cc
tensorflow/contrib/lite/nnapi_delegate.h

index 107c84e..eed57d4 100644 (file)
@@ -155,7 +155,6 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
         nn_type, static_cast<uint32_t>(tensor->dims->size),
         reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
     CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
-
     // TODO(aselle): Based on Michael's suggestion, limiting this to read
     // only memory
     if (tensor->allocation_type == kTfLiteMmapRo) {
@@ -168,7 +167,12 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
         CHECK_NN(ANeuralNetworksModel_setOperandValue(
             nn_model, next_id, tensor->data.raw, tensor->bytes));
       }
+    } else if (tensor->bytes == 0) {
+      // These size 0 tensors are optional tensors reserved.
+      CHECK_NN(
+          ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0));
     }
+
     ++next_id;
   }
   return next_id;
@@ -177,7 +181,9 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
 // Adds the operations and their parameters to the NN API model.
 // 'next-id' is the operand ID of the next operand of the model.
 void AddOpsAndParams(tflite::Interpreter* interpreter,
-                     ANeuralNetworksModel* nn_model, uint32_t next_id) {
+                     ANeuralNetworksModel* nn_model, uint32_t next_id,
+                     std::vector<int>* model_state_inputs,
+                     std::vector<int>* model_state_outputs) {
   for (size_t i = 0; i < interpreter->nodes_size(); i++) {
     const auto* node_and_registration = interpreter->node_and_registration(i);
     const TfLiteNode& node = node_and_registration->first;
@@ -188,6 +194,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
     // Add the parameters.
     std::vector<uint32_t> augmented_inputs(
         node.inputs->data, node.inputs->data + node.inputs->size);
+    std::vector<uint32_t> augmented_outputs(
+        node.outputs->data, node.outputs->data + node.outputs->size);
 
     auto add_scalar_int32 = [&nn_model, &augmented_inputs,
                              &next_id](int value) {
@@ -207,12 +215,23 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       augmented_inputs.push_back(next_id++);
     };
 
+    // Handle state tensors of RNN, LSTM, SVDF.
+    // For each state_out tensor, a corresponding state_in operand needs to be
+    // created for NNAPI.
     auto duplicate_state_tensor_float32 =
-        [interpreter, &nn_model, &augmented_inputs](int tensor_id) {
+        [interpreter, &nn_model, &next_id, &augmented_inputs,
+         &model_state_inputs, &model_state_outputs](int tensor_id) {
           const TfLiteTensor* tensor = interpreter->tensor(tensor_id);
-          CHECK_NN(ANeuralNetworksModel_setOperandValue(
-              nn_model, tensor_id, tensor->data.raw, tensor->bytes));
-          augmented_inputs.push_back(tensor_id);
+          ANeuralNetworksOperandType operand_type{
+              ANEURALNETWORKS_TENSOR_FLOAT32,
+              static_cast<uint32_t>(tensor->dims->size),
+              reinterpret_cast<uint32_t*>(tensor->dims->data),
+              tensor->params.scale, tensor->params.zero_point};
+          CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+          augmented_inputs.push_back(next_id);
+          model_state_inputs->push_back(next_id);
+          model_state_outputs->push_back(tensor_id);
+          next_id++;
         };
 
     auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); };
@@ -275,28 +294,51 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       add_scalar_float32(builtin->proj_clip);
     };
 
+    // LSTM in NNAPI requires scratch tensor as an output operand.
+    auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model,
+                                            &next_id, &augmented_outputs]() {
+      int scratch_buffer_index = node.temporaries->data[0];
+      const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index);
+      ANeuralNetworksOperandType operand_type{
+          ANEURALNETWORKS_TENSOR_FLOAT32,
+          static_cast<uint32_t>(tensor->dims->size),
+          reinterpret_cast<uint32_t*>(tensor->dims->data), tensor->params.scale,
+          tensor->params.zero_point};
+      CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+      augmented_outputs.insert(augmented_outputs.begin(), next_id++);
+    };
+
     auto add_mean_params = [&add_scalar_int32](void* data) {
       auto builtin = reinterpret_cast<TfLiteMeanParams*>(data);
       add_scalar_int32(builtin->keep_dims);
     };
 
-#if 0
-    auto add_reshape_params = [&](void* data) {
-      auto builtin = reinterpret_cast<TfLiteReshapeParams*>(data);
-      uint32_t tensor_size_shape = builtin->num_dimensions;
-      ANeuralNetworksOperandType operand_type{
-          ANEURALNETWORKS_TENSOR_INT32,
-          {static_cast<uint32_t>(1),
-           reinterpret_cast<uint32_t*>(&tensor_size_shape)},
-          0,
-          0};
-      CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
-      CHECK_NN(ANeuralNetworksModel_setOperandValue(
-          nn_model, next_id, builtin->shape,
-          sizeof(int) * builtin->num_dimensions));
-      augmented_inputs.push_back(next_id++);
+    auto add_svdf_params = [&add_scalar_int32](void* data) {
+      auto builtin = reinterpret_cast<TfLiteSVDFParams*>(data);
+      add_scalar_int32(builtin->rank);
+      add_scalar_int32(builtin->activation);
     };
-#endif
+
+    auto add_rnn_params = [&add_scalar_int32](void* data) {
+      auto builtin = reinterpret_cast<TfLiteRNNParams*>(data);
+      add_scalar_int32(builtin->activation);
+    };
+
+    // Handle optional input tensors.
+    auto add_optional_tensors = [&nn_model, &augmented_inputs,
+                                 &next_id](int nn_type) {
+      for (size_t idx = 0; idx < augmented_inputs.size(); idx++) {
+        if (augmented_inputs[idx] == kOptionalTensor) {
+          const std::vector<uint32_t> dim = {0, 0};
+          ANeuralNetworksOperandType operand_type{nn_type, 2, dim.data(), 0, 0};
+          CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+          CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id,
+                                                        nullptr, 0))
+          augmented_inputs[idx] = next_id++;
+        }
+      }
+    };
+
     int nnapi_version = 10;
     ANeuralNetworksOperationType nn_op_type;
 
@@ -366,13 +408,31 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
         break;
       case tflite::BuiltinOperator_LSTM: {
         duplicate_state_tensor_float32(
-            node.outputs->data[/*kOutputStateTensor*/ 1]);
+            node.outputs->data[/*kOutputStateTensor*/ 0]);
         duplicate_state_tensor_float32(
-            node.outputs->data[/*kCellStateTensor*/ 2]);
+            node.outputs->data[/*kCellStateTensor*/ 1]);
         add_lstm_params(node.builtin_data);
+        add_lstm_scratch_tensor_float32();
+        add_optional_tensors(ANEURALNETWORKS_TENSOR_FLOAT32);
         nn_op_type = ANEURALNETWORKS_LSTM;
         break;
       }
+      case tflite::BuiltinOperator_SVDF: {
+        duplicate_state_tensor_float32(node.outputs->data[/*kStateTensor*/ 0]);
+        add_svdf_params(node.builtin_data);
+        nn_op_type = ANEURALNETWORKS_SVDF;
+        break;
+      }
+      case tflite::BuiltinOperator_RNN: {
+        duplicate_state_tensor_float32(
+            node.outputs->data[/*kHiddenStateTensor*/ 0]);
+        add_rnn_params(node.builtin_data);
+        nn_op_type = ANEURALNETWORKS_RNN;
+        break;
+      }
+      case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
+        nn_op_type = ANEURALNETWORKS_EMBEDDING_LOOKUP;
+        break;
       case tflite::BuiltinOperator_PAD:
         nnapi_version = 11;  // require NNAPI 1.1
         nn_op_type = ANEURALNETWORKS_PAD;
@@ -392,12 +452,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
         break;
       case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
       case tflite::BuiltinOperator_LSH_PROJECTION:
-      case tflite::BuiltinOperator_SVDF:
       case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
-      case tflite::BuiltinOperator_RNN:
       case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
       case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN:
-      case tflite::BuiltinOperator_EMBEDDING_LOOKUP:
       case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE:
       case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
       case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
@@ -450,8 +507,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
     // Add the operation.
     CHECK_NN(ANeuralNetworksModel_addOperation(
         nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
-        augmented_inputs.data(), static_cast<uint32_t>(node.outputs->size),
-        reinterpret_cast<uint32_t*>(node.outputs->data)));
+        augmented_inputs.data(),
+        static_cast<uint32_t>(augmented_outputs.size()),
+        reinterpret_cast<uint32_t*>(augmented_outputs.data())));
   }
 }
 
@@ -475,12 +533,25 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
     }
 
     uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list);
-    AddOpsAndParams(interpreter, nn_model_, next_id);
+    AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
+                    &model_states_outputs_);
+
+    std::vector<int> augmented_inputs = interpreter->inputs();
+    std::vector<int> augmented_outputs = interpreter->outputs();
+
+    // All state tensors input/output need to be treated as model input/output.
+    augmented_inputs.insert(augmented_inputs.end(),
+                            model_states_inputs_.begin(),
+                            model_states_inputs_.end());
+    augmented_outputs.insert(augmented_outputs.end(),
+                             model_states_outputs_.begin(),
+                             model_states_outputs_.end());
+
     CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs(
-        nn_model_, static_cast<uint32_t>(interpreter->inputs().size()),
-        reinterpret_cast<const uint32_t*>(interpreter->inputs().data()),
-        static_cast<uint32_t>(interpreter->outputs().size()),
-        reinterpret_cast<const uint32_t*>(interpreter->outputs().data())));
+        nn_model_, static_cast<uint32_t>(augmented_inputs.size()),
+        reinterpret_cast<const uint32_t*>(augmented_inputs.data()),
+        static_cast<uint32_t>(augmented_outputs.size()),
+        reinterpret_cast<const uint32_t*>(augmented_outputs.data())));
     CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
   }
   if (!nn_compiled_model_) {
@@ -507,6 +578,7 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
     CHECK_NN(ANeuralNetworksExecution_setInput(
         execution, i, nullptr, tensor->data.raw, tensor->bytes));
   }
+
   // Tell nn api where to place final data.
   for (size_t i = 0; i < interpreter->outputs().size(); i++) {
     int output = interpreter->outputs()[i];
@@ -514,6 +586,24 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
     CHECK_NN(ANeuralNetworksExecution_setOutput(
         execution, i, nullptr, tensor->data.raw, tensor->bytes));
   }
+
+  // The state_out of previous invocation need to be mapped to state_in of
+  // current invocation.
+  for (size_t i = 0; i < model_states_outputs_.size(); i++) {
+    int state_tensor_idx = model_states_outputs_[i];
+    TfLiteTensor* tensor = interpreter->tensor(state_tensor_idx);
+    // Here we are using a deep copy for state_in tensors so that we are not
+    // reading and writing into the same buffer during a invocation.
+    // TODO(miaowang): using double shared buffer to minimize the copies.
+    CHECK_NN(ANeuralNetworksExecution_setInput(
+        execution, i + interpreter->inputs().size(), nullptr, tensor->data.raw,
+        tensor->bytes));
+    // Tell NNAPI where to output the state_out.
+    CHECK_NN(ANeuralNetworksExecution_setOutput(
+        execution, i + interpreter->outputs().size(), nullptr, tensor->data.raw,
+        tensor->bytes));
+  }
+
   // Currently use blocking compute.
   ANeuralNetworksEvent* event = nullptr;
   CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event));
index e980009..94dea4f 100644 (file)
@@ -59,6 +59,14 @@ class NNAPIDelegate {
   ANeuralNetworksModel* nn_model_ = nullptr;
   // The NN API compilation handle
   ANeuralNetworksCompilation* nn_compiled_model_ = nullptr;
+
+  // List of state tensors for LSTM, RNN, SVDF.
+  // NN API does not allow ops to maintain states across multiple
+  // invocations. We need to manually create state input tensors from
+  // corresponding state output tensors of TFLite operations, and map them
+  // correctly.
+  std::vector<int> model_states_inputs_;
+  std::vector<int> model_states_outputs_;
 };
 
 }  // namespace tflite