[TFLite] Don't require a std::vector for Interpreter::SetTensorParameters*.
authorEugene Brevdo <ebrevdo@google.com>
Mon, 12 Mar 2018 20:35:03 +0000 (13:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 20:39:26 +0000 (13:39 -0700)
PiperOrigin-RevId: 188770522

tensorflow/contrib/lite/interpreter.cc
tensorflow/contrib/lite/interpreter.h
tensorflow/contrib/lite/util.cc
tensorflow/contrib/lite/util.h

index bbcd318..f03c1c9 100644 (file)
@@ -575,9 +575,9 @@ TfLiteStatus Interpreter::GetNodeAndRegistration(
 }
 
 TfLiteStatus Interpreter::SetTensorParametersReadOnly(
-    int tensor_index, TfLiteType type, const char* name,
-    const std::vector<int>& dims, TfLiteQuantizationParams quantization,
-    const char* buffer, size_t bytes, const Allocation* allocation) {
+    int tensor_index, TfLiteType type, const char* name, const int rank,
+    const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
+    size_t bytes, const Allocation* allocation) {
   TF_LITE_ENSURE(&context_,
                  tensor_index < context_.tensors_size && tensor_index >= 0);
   // For most tensors we know exactly how much memory is necessary so we can
@@ -585,23 +585,24 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
   // because their sizes change with the contents of the individual strings.
   if (type != kTfLiteString) {
     size_t required_bytes;
-    TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
-                                               &required_bytes));
+    TF_LITE_ENSURE_OK(&context_,
+                      BytesRequired(type, dims, rank, &required_bytes));
     TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
   }
 
   TfLiteTensor& tensor = context_.tensors[tensor_index];
-  if (type == tensor.type && EqualVectorAndTfLiteIntArray(tensor.dims, dims)) {
+  if (type == tensor.type &&
+      EqualArrayAndTfLiteIntArray(tensor.dims, rank, dims)) {
     // Fast path which does not invalidate the invokable property.
     TfLiteTensorDataFree(&tensor);
     tensor.data.raw = const_cast<char*>(buffer);
-    if (!tensor.dims) tensor.dims = ConvertVectorToTfLiteIntArray(dims);
+    if (!tensor.dims) tensor.dims = ConvertArrayToTfLiteIntArray(rank, dims);
     tensor.params = quantization;
     tensor.allocation_type = kTfLiteMmapRo;
     tensor.allocation = allocation;
   } else {
     invokable_ = false;
-    TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+    TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
                       quantization, const_cast<char*>(buffer), bytes,
                       kTfLiteMmapRo, allocation, &tensor);
   }
@@ -613,8 +614,8 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
 // bytes. The lifetime of buffer must be ensured to be greater or equal
 // to Interpreter.
 TfLiteStatus Interpreter::SetTensorParametersReadWrite(
-    int tensor_index, TfLiteType type, const char* name,
-    const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+    int tensor_index, TfLiteType type, const char* name, const int rank,
+    const int* dims, TfLiteQuantizationParams quantization) {
   invokable_ = false;
   TF_LITE_ENSURE(&context_,
                  tensor_index < context_.tensors_size && tensor_index >= 0);
@@ -624,10 +625,10 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
     // many bytes we will need based on the dimensions. String tensors are
     // allocated dynamically and we can't know ahead of time how much space
     // they will require.
-    TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
-                                               &required_bytes));
+    TF_LITE_ENSURE_OK(&context_,
+                      BytesRequired(type, dims, rank, &required_bytes));
   }
-  TfLiteTensorReset(type, name, ConvertVectorToTfLiteIntArray(dims),
+  TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
                     quantization,
                     /*buffer=*/nullptr, required_bytes,
                     type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
index f2d4a05..7c5a195 100644 (file)
@@ -134,18 +134,34 @@ class Interpreter {
   // This variant assumes an external buffer has been allocated of size
   // bytes. The lifetime of buffer must be ensured to be greater or equal
   // to Interpreter.
-  TfLiteStatus SetTensorParametersReadOnly(
+  inline TfLiteStatus SetTensorParametersReadOnly(
       int tensor_index, TfLiteType type, const char* name,
       const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+      const char* buffer, size_t bytes,
+      const Allocation* allocation = nullptr) {
+    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
+                                       dims.data(), quantization, buffer, bytes,
+                                       allocation);
+  };
+
+  TfLiteStatus SetTensorParametersReadOnly(
+      int tensor_index, TfLiteType type, const char* name, const int rank,
+      const int* dims, TfLiteQuantizationParams quantization,
       const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
 
   // Set description of inputs/outputs/data/fptrs for node `node_index`.
   // This variant assumes an external buffer has been allocated of size
   // bytes. The lifetime of buffer must be ensured to be greater or equal
   // to Interpreter.
-  TfLiteStatus SetTensorParametersReadWrite(
+  inline TfLiteStatus SetTensorParametersReadWrite(
       int tensor_index, TfLiteType type, const char* name,
-      const std::vector<int>& dims, TfLiteQuantizationParams quantization);
+      const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+    return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
+                                        dims.data(), quantization);
+  }
+  TfLiteStatus SetTensorParametersReadWrite(
+      int tensor_index, TfLiteType type, const char* name, const int rank,
+      const int* dims, TfLiteQuantizationParams quantization);
 
   // Functions to access tensor data
 
index b7f31e2..fb4af07 100644 (file)
@@ -17,17 +17,21 @@ limitations under the License.
 namespace tflite {
 
 TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
-  TfLiteIntArray* output = TfLiteIntArrayCreate(input.size());
-  for (size_t i = 0; i < input.size(); i++) {
-    output->data[i] = input[i];
+  return ConvertArrayToTfLiteIntArray(input.size(), input.data());
+}
+
+TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims) {
+  TfLiteIntArray* output = TfLiteIntArrayCreate(rank);
+  for (size_t i = 0; i < rank; i++) {
+    output->data[i] = dims[i];
   }
   return output;
 }
 
-bool EqualVectorAndTfLiteIntArray(const TfLiteIntArray* a,
-                                  const std::vector<int>& b) {
+bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
+                                 const int* b) {
   if (!a) return false;
-  if (a->size != b.size()) return false;
+  if (a->size != b_size) return false;
   for (int i = 0; i < a->size; ++i) {
     if (a->data[i] != b[i]) return false;
   }
index f505d82..a34db35 100644 (file)
@@ -29,9 +29,11 @@ namespace tflite {
 // Converts a `std::vector` to a `TfLiteIntArray`.
 TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input);
 
-// Checks whether a `TfLiteIntArray` and `std::vector` have matching elements.
-bool EqualVectorAndTfLiteIntArray(const TfLiteIntArray* a,
-                                  const std::vector<int>& b);
+TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims);
+
+// Checks whether a `TfLiteIntArray` and an int array have matching elements.
+bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
+                                 const int* b);
 
 }  // namespace tflite