A more efficient implementation of the Op using batch operations.
author A. Unique TensorFlower <gardener@tensorflow.org>
Sat, 3 Feb 2018 03:40:30 +0000 (19:40 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sat, 3 Feb 2018 03:44:35 +0000 (19:44 -0800)
PiperOrigin-RevId: 184367562

tensorflow/contrib/lite/kernels/BUILD
tensorflow/contrib/lite/kernels/basic_rnn.cc
tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
tensorflow/contrib/lite/kernels/internal/BUILD
tensorflow/contrib/lite/kernels/internal/kernel_utils.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/kernel_utils.h [new file with mode: 0644]
tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc

index 8c40adf..a8ef0da 100644 (file)
@@ -156,6 +156,7 @@ cc_library(
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:string_util",
         "//tensorflow/contrib/lite/kernels:gemm_support",
+        "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
         "//tensorflow/contrib/lite/kernels/internal:optimized",
         "//tensorflow/contrib/lite/kernels/internal:optimized_base",
         "//tensorflow/contrib/lite/kernels/internal:quantization_util",
index a0391e0..2c5074e 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/context.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 
 namespace tflite {
@@ -101,50 +102,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const int batch_size = input->dims->data[0];
   const int num_units = input_weights->dims->data[0];
   const int input_size = input->dims->data[1];
-  const int input_weights_stride = input_weights->dims->data[1];
-  const int recurrent_weights_stride = recurrent_weights->dims->data[1];
-
-  // For each batch
-  for (int b = 0; b < batch_size; b++) {
-    // Initialize the pointer to input, output and bias.
-    const float* input_ptr_batch = input->data.f + b * input_size;
-    float* output_ptr_batch = output->data.f + b * num_units;
-    float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
-
-    // Initialize input_weights and recurrent_weights.
-    const float* input_weights_ptr = input_weights->data.f;
-    const float* recurrent_weights_ptr = recurrent_weights->data.f;
-
-    // Output = bias
-    for (int o = 0; o < num_units; o++) {
-      output_ptr_batch[o] = bias_ptr[o];
-    }
-
-    // Output += input * input_weights
-    for (int o = 0; o < num_units; o++) {
-      for (int i = 0; i < input_size; i++) {
-        output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
-      }
-      input_weights_ptr += input_weights_stride;
-    }
-
-    // Output += recurrent_weights * hidden_state
-    for (int o = 0; o < num_units; o++) {
-      for (int h = 0; h < num_units; h++) {
-        output_ptr_batch[o] +=
-            hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
-      }
-      recurrent_weights_ptr += recurrent_weights_stride;
-    }
-
-    // Output = activation(Output) and update hidden_state
-    for (int o = 0; o < num_units; o++) {
-      output_ptr_batch[o] =
-          (ActivationFunctor(params->activation))(output_ptr_batch[o]);
-      hidden_state_ptr_batch[o] = output_ptr_batch[o];
-    }
-  }
 
+  // Initialize the pointer to hidden state.
+  float* hidden_state_ptr_batch = hidden_state->data.f;
+  // Initialize the pointer to input and output.
+  const float* input_ptr_batch = input->data.f;
+  float* output_ptr_batch = output->data.f;
+  // Initialize input_weights and recurrent_weights.
+  const float* input_weights_ptr = input_weights->data.f;
+  const float* recurrent_weights_ptr = recurrent_weights->data.f;
+
+  kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
+                             recurrent_weights_ptr, bias_ptr, input_size,
+                             num_units, batch_size, params->activation,
+                             hidden_state_ptr_batch, output_ptr_batch);
   return kTfLiteOk;
 }
 
index f540816..aa24c1f 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/context.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 
 namespace tflite {
@@ -119,47 +120,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-namespace {
-// Performs one RNN computation step for the input specified by input_ptr_batch.
-// The RNN cell is specified by the pointers to its weights and biases, along
-// with the input size, number of units, strides, activation.
-// The pointers to the hidden state and the output are updated as a result.
-// TODO(mirkov): factor out this function to a shared library.
-void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr,
-             const float* recurrent_weights_ptr, const float* bias_ptr,
-             int input_size, int num_units, int input_weights_stride,
-             int recurrent_weights_stride, TfLiteFusedActivation activation,
-             float* hidden_state_ptr_batch, float* output_ptr_batch) {
-  // Output = bias
-  for (int o = 0; o < num_units; o++) {
-    output_ptr_batch[o] = bias_ptr[o];
-  }
-
-  // Output += input * input_weights
-  for (int o = 0; o < num_units; o++) {
-    for (int i = 0; i < input_size; i++) {
-      output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
-    }
-    input_weights_ptr += input_weights_stride;
-  }
-
-  // Output += recurrent_weights * hidden_state
-  for (int o = 0; o < num_units; o++) {
-    for (int h = 0; h < num_units; h++) {
-      output_ptr_batch[o] +=
-          hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
-    }
-    recurrent_weights_ptr += recurrent_weights_stride;
-  }
-
-  // Output = activation(Output) and update hidden_state
-  for (int o = 0; o < num_units; o++) {
-    output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]);
-    hidden_state_ptr_batch[o] = output_ptr_batch[o];
-  }
-}
-}  // namespace
-
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
 
@@ -189,15 +149,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const int input_size = input->dims->data[2];
 
   const int fw_num_units = fw_input_weights->dims->data[0];
-  const int fw_input_weights_stride = fw_input_weights->dims->data[1];
-  const int fw_recurrent_weights_stride = fw_recurrent_weights->dims->data[1];
   const float* fw_bias_ptr = fw_bias->data.f;
   const float* fw_input_weights_ptr = fw_input_weights->data.f;
   const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f;
 
   const int bw_num_units = bw_input_weights->dims->data[0];
-  const int bw_input_weights_stride = bw_input_weights->dims->data[1];
-  const int bw_recurrent_weights_stride = bw_recurrent_weights->dims->data[1];
   const float* bw_bias_ptr = bw_bias->data.f;
   const float* bw_input_weights_ptr = bw_input_weights->data.f;
   const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;
@@ -212,10 +168,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       float* output_ptr_batch =
           fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
 
-      RnnStep(input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
-              fw_bias_ptr, input_size, fw_num_units, fw_input_weights_stride,
-              fw_recurrent_weights_stride, params->activation,
-              fw_hidden_state_ptr_batch, output_ptr_batch);
+      kernel_utils::RnnBatchStep(
+          input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
+          fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1,
+          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
     }
     // Backward cell.
     float* bw_hidden_state_ptr_batch =
@@ -226,10 +182,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       float* output_ptr_batch =
           bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
 
-      RnnStep(input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
-              bw_bias_ptr, input_size, bw_num_units, bw_input_weights_stride,
-              bw_recurrent_weights_stride, params->activation,
-              bw_hidden_state_ptr_batch, output_ptr_batch);
+      kernel_utils::RnnBatchStep(
+          input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
+          bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1,
+          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
     }
   }
   return kTfLiteOk;
index 288f1f8..4691a54 100644 (file)
@@ -291,6 +291,16 @@ cc_library(
 )
 
 cc_library(
+    name = "kernel_utils",
+    srcs = ["kernel_utils.cc"],
+    hdrs = ["kernel_utils.h"],
+    deps = [
+        ":tensor_utils",
+        "//tensorflow/contrib/lite:builtin_op_data",
+    ],
+)
+
+cc_library(
     name = "tensor_utils",
     srcs = [
         "tensor_utils.cc",
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
new file mode 100644 (file)
index 0000000..5103951
--- /dev/null
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+
+namespace tflite {
+namespace kernel_utils {
+
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+                  const float* recurrent_weights_ptr, const float* bias_ptr,
+                  int input_size, int num_units, int batch_size,
+                  TfLiteFusedActivation activation,
+                  float* hidden_state_ptr_batch, float* output_ptr_batch) {
+  // Output = bias
+  tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
+                                        output_ptr_batch);
+  // Output += input * input_weights
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
+      output_ptr_batch, /*result_stride=*/1);
+  // Output += recurrent_weights * hidden_state
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
+      batch_size, output_ptr_batch, /*result_stride=*/1);
+  // Output = activation(Output) and update hidden_state
+  tensor_utils::ApplyActivationToVector(
+      output_ptr_batch, num_units * batch_size, activation, output_ptr_batch);
+  tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size,
+                                        hidden_state_ptr_batch);
+}
+
+}  // namespace kernel_utils
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
new file mode 100644 (file)
index 0000000..9872d45
--- /dev/null
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+
+namespace tflite {
+namespace kernel_utils {
+
+// Performs an RNN batch inference step for inputs specified by input_ptr_batch.
+// The RNN cell is specified by the pointers to its input and recurrent weights,
+// and biases, along with the input size, number of units, activation.
+//
+// The pointers to the hidden state and the output are updated as a result.
+//
+// The pointers with the suffix "_batch" point to data aligned in batch_major
+// order, and each step processes batch_size many inputs from input_ptr_batch,
+// and updates batch_size many outputs and hidden states.
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+                  const float* recurrent_weights_ptr, const float* bias_ptr,
+                  int input_size, int num_units, int batch_size,
+                  TfLiteFusedActivation activation,
+                  float* hidden_state_ptr_batch, float* output_ptr_batch);
+
+}  // namespace kernel_utils
+}  // namespace tflite
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
index 7ce87e4..ac00c37 100644 (file)
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/context.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 
 namespace tflite {
@@ -88,42 +89,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-namespace {
-void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr,
-             const float* recurrent_weights_ptr, const float* bias_ptr,
-             int input_size, int num_units, int input_weights_stride,
-             int recurrent_weights_stride, TfLiteFusedActivation activation,
-             float* hidden_state_ptr_batch, float* output_ptr_batch) {
-  // Output = bias
-  for (int o = 0; o < num_units; o++) {
-    output_ptr_batch[o] = bias_ptr[o];
-  }
-
-  // Output += input * input_weights
-  for (int o = 0; o < num_units; o++) {
-    for (int i = 0; i < input_size; i++) {
-      output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
-    }
-    input_weights_ptr += input_weights_stride;
-  }
-
-  // Output += recurrent_weights * hidden_state
-  for (int o = 0; o < num_units; o++) {
-    for (int h = 0; h < num_units; h++) {
-      output_ptr_batch[o] +=
-          hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
-    }
-    recurrent_weights_ptr += recurrent_weights_stride;
-  }
-
-  // Output = activation(Output) and update hidden_state
-  for (int o = 0; o < num_units; o++) {
-    output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]);
-    hidden_state_ptr_batch[o] = output_ptr_batch[o];
-  }
-}
-}  // namespace
-
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
 
@@ -147,30 +112,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       (time_major) ? input->dims->data[0] : input->dims->data[1];
   const int num_units = input_weights->dims->data[0];
   const int input_size = input->dims->data[2];
-  const int input_weights_stride = input_weights->dims->data[1];
-  const int recurrent_weights_stride = recurrent_weights->dims->data[1];
 
   // Initialize input_weights and recurrent_weights.
   const float* input_weights_ptr = input_weights->data.f;
   const float* recurrent_weights_ptr = recurrent_weights->data.f;
 
   if (time_major) {
-    // Unroll the sequence
+    // Initialize the pointer to hidden state.
+    float* hidden_state_ptr_batch = hidden_state->data.f;
+    // Unroll the sequence and use batch operations for efficiency.
     for (int s = 0; s < max_time; s++) {
-      for (int b = 0; b < batch_size; b++) {
-        // Initialize the pointer to hidden state.
-        float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units;
-        // Initialize the pointer to input and output.
-        const float* input_ptr_batch =
-            input->data.f + s * input_size * batch_size + b * input_size;
-        float* output_ptr_batch =
-            output->data.f + s * num_units * batch_size + b * num_units;
-
-        RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
-                bias_ptr, input_size, num_units, input_weights_stride,
-                recurrent_weights_stride, params->activation,
-                hidden_state_ptr_batch, output_ptr_batch);
-      }
+      // Initialize the pointer to input and output.
+      const float* input_ptr_batch =
+          input->data.f + s * input_size * batch_size;
+      float* output_ptr_batch = output->data.f + s * num_units * batch_size;
+
+      kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr,
+                                 recurrent_weights_ptr, bias_ptr, input_size,
+                                 num_units, batch_size, params->activation,
+                                 hidden_state_ptr_batch, output_ptr_batch);
     }
   } else {
     // For each batch
@@ -184,10 +144,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         float* output_ptr_batch =
             output->data.f + b * num_units * max_time + s * num_units;
 
-        RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
-                bias_ptr, input_size, num_units, input_weights_stride,
-                recurrent_weights_stride, params->activation,
-                hidden_state_ptr_batch, output_ptr_batch);
+        kernel_utils::RnnBatchStep(
+            input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
+            input_size, num_units, /*batch_size=*/1, params->activation,
+            hidden_state_ptr_batch, output_ptr_batch);
       }
     }
   }