From 39010bef7f72709a87a275060878baac815744c2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Feb 2018 19:40:30 -0800 Subject: [PATCH] A more efficient implementation of the Op using batch operations. PiperOrigin-RevId: 184367562 --- tensorflow/contrib/lite/kernels/BUILD | 1 + tensorflow/contrib/lite/kernels/basic_rnn.cc | 57 ++++------------- .../lite/kernels/bidirectional_sequence_rnn.cc | 62 +++--------------- tensorflow/contrib/lite/kernels/internal/BUILD | 10 +++ .../contrib/lite/kernels/internal/kernel_utils.cc | 44 +++++++++++++ .../contrib/lite/kernels/internal/kernel_utils.h | 40 ++++++++++++ .../lite/kernels/unidirectional_sequence_rnn.cc | 74 +++++----------------- 7 files changed, 135 insertions(+), 153 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/internal/kernel_utils.cc create mode 100644 tensorflow/contrib/lite/kernels/internal/kernel_utils.h diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 8c40adf..a8ef0da 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -156,6 +156,7 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index a0391e0..2c5074e 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -101,50 +102,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; - const int input_weights_stride = input_weights->dims->data[1]; - const int recurrent_weights_stride = recurrent_weights->dims->data[1]; - - // For each batch - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to input, output and bias. - const float* input_ptr_batch = input->data.f + b * input_size; - float* output_ptr_batch = output->data.f + b * num_units; - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - - // Initialize input_weights and recurrent_weights. 
- const float* input_weights_ptr = input_weights->data.f; - const float* recurrent_weights_ptr = recurrent_weights->data.f; - - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = - (ActivationFunctor(params->activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } - } + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Initialize the pointer to input and output. + const float* input_ptr_batch = input->data.f; + float* output_ptr_batch = output->data.f; + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index f540816..aa24c1f 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -119,47 +120,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -namespace { -// Performs one RNN computation step for the input specified by input_ptr_batch. -// The RNN cell is specified by the pointers to its weights and biases, along -// with the input size, number of units, strides, activation. -// The pointers to the hidden state and the output are updated as a result. -// TODO(mirkov): factor out this function to a shared library. -void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr, - const float* recurrent_weights_ptr, const float* bias_ptr, - int input_size, int num_units, int input_weights_stride, - int recurrent_weights_stride, TfLiteFusedActivation activation, - float* hidden_state_ptr_batch, float* output_ptr_batch) { - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } -} -} // namespace - TfLiteStatus 
Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -189,15 +149,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int input_size = input->dims->data[2]; const int fw_num_units = fw_input_weights->dims->data[0]; - const int fw_input_weights_stride = fw_input_weights->dims->data[1]; - const int fw_recurrent_weights_stride = fw_recurrent_weights->dims->data[1]; const float* fw_bias_ptr = fw_bias->data.f; const float* fw_input_weights_ptr = fw_input_weights->data.f; const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f; const int bw_num_units = bw_input_weights->dims->data[0]; - const int bw_input_weights_stride = bw_input_weights->dims->data[1]; - const int bw_recurrent_weights_stride = bw_recurrent_weights->dims->data[1]; const float* bw_bias_ptr = bw_bias->data.f; const float* bw_input_weights_ptr = bw_input_weights->data.f; const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f; @@ -212,10 +168,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_ptr_batch = fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; - RnnStep(input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr, - fw_bias_ptr, input_size, fw_num_units, fw_input_weights_stride, - fw_recurrent_weights_stride, params->activation, - fw_hidden_state_ptr_batch, output_ptr_batch); + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr, + fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1, + params->activation, fw_hidden_state_ptr_batch, output_ptr_batch); } // Backward cell. 
float* bw_hidden_state_ptr_batch = @@ -226,10 +182,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_ptr_batch = bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; - RnnStep(input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr, - bw_bias_ptr, input_size, bw_num_units, bw_input_weights_stride, - bw_recurrent_weights_stride, params->activation, - bw_hidden_state_ptr_batch, output_ptr_batch); + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr, + bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1, + params->activation, bw_hidden_state_ptr_batch, output_ptr_batch); } } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 288f1f8..4691a54 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -291,6 +291,16 @@ cc_library( ) cc_library( + name = "kernel_utils", + srcs = ["kernel_utils.cc"], + hdrs = ["kernel_utils.h"], + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + ], +) + +cc_library( name = "tensor_utils", srcs = [ "tensor_utils.cc", diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc new file mode 100644 index 0000000..5103951 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +namespace tflite { +namespace kernel_utils { + +void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch) { + // Output = bias + tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, + output_ptr_batch); + // Output += input * input_weights + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size, + output_ptr_batch, /*result_stride=*/1); + // Output += recurrent_weights * hidden_state + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch, + batch_size, output_ptr_batch, /*result_stride=*/1); + // Output = activation(Output) and update hidden_state + tensor_utils::ApplyActivationToVector( + output_ptr_batch, num_units * batch_size, activation, output_ptr_batch); + tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size, + hidden_state_ptr_batch); +} + +} // namespace kernel_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h new file mode 100644 index 0000000..9872d45 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace kernel_utils { + +// Performs an RNN batch inference step for inputs specified by input_ptr_batch. +// The RNN cell is specified by the pointers to its input and recurrent weights, +// and biases, along with the input size, number of units, activation. +// +// The pointers to the hidden state and the output are updated as a result. +// +// The pointers with the suffix "_batch" point to data aligned in batch_major +// order, and each step processes batch_size many inputs from input_ptr_batch, +// and updates batch_size many outputs and hidden states. 
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch); + +} // namespace kernel_utils +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 7ce87e4..ac00c37 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -88,42 +89,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -namespace { -void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr, - const float* recurrent_weights_ptr, const float* bias_ptr, - int input_size, int num_units, int input_weights_stride, - int recurrent_weights_stride, TfLiteFusedActivation activation, - float* hidden_state_ptr_batch, float* output_ptr_batch) { - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - 
output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } -} -} // namespace - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -147,30 +112,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { (time_major) ? input->dims->data[0] : input->dims->data[1]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[2]; - const int input_weights_stride = input_weights->dims->data[1]; - const int recurrent_weights_stride = recurrent_weights->dims->data[1]; // Initialize input_weights and recurrent_weights. const float* input_weights_ptr = input_weights->data.f; const float* recurrent_weights_ptr = recurrent_weights->data.f; if (time_major) { - // Unroll the sequence + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Unroll the sequence and use batch operations for efficiency. for (int s = 0; s < max_time; s++) { - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to hidden state. - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - // Initialize the pointer to input and output. - const float* input_ptr_batch = - input->data.f + s * input_size * batch_size + b * input_size; - float* output_ptr_batch = - output->data.f + s * num_units * batch_size + b * num_units; - - RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, - bias_ptr, input_size, num_units, input_weights_stride, - recurrent_weights_stride, params->activation, - hidden_state_ptr_batch, output_ptr_batch); - } + // Initialize the pointer to input and output. 
+ const float* input_ptr_batch = + input->data.f + s * input_size * batch_size; + float* output_ptr_batch = output->data.f + s * num_units * batch_size; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } } else { // For each batch @@ -184,10 +144,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_ptr_batch = output->data.f + b * num_units * max_time + s * num_units; - RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, - bias_ptr, input_size, num_units, input_weights_stride, - recurrent_weights_stride, params->activation, - hidden_state_ptr_batch, output_ptr_batch); + kernel_utils::RnnBatchStep( + input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr, + input_size, num_units, /*batch_size=*/1, params->activation, + hidden_state_ptr_batch, output_ptr_batch); } } } -- 2.7.4