TFLite Delegate: Add an `allow_dynamic_tensors` parameter.
authorYu-Cheng Ling <ycling@google.com>
Mon, 19 Mar 2018 21:27:52 +0000 (14:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 19 Mar 2018 21:37:01 +0000 (14:37 -0700)
PiperOrigin-RevId: 189641833

tensorflow/contrib/lite/BUILD
tensorflow/contrib/lite/interpreter.cc
tensorflow/contrib/lite/interpreter.h
tensorflow/contrib/lite/interpreter_test.cc

index 5cfbb54..dafe6f1 100644 (file)
@@ -170,6 +170,7 @@ cc_test(
     deps = [
         ":framework",
         ":string_util",
+        "//tensorflow/contrib/lite/kernels:kernel_util",
         "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
         "//tensorflow/contrib/lite/schema:schema_fbs",
         "//tensorflow/contrib/lite/testing:util",
index cee57bb..937c185 100644 (file)
@@ -356,7 +356,11 @@ TfLiteStatus Interpreter::AllocateTensors() {
   }
 
   TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
-  invokable_ = true;
+  if (state_ == kStateUninvokable) {
+    state_ = kStateInvokable;
+  }
+  TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
+                                state_ == kStateInvokableAndImmutable);
   return kTfLiteOk;
 }
 
@@ -364,7 +368,12 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
     const std::vector<int>& inputs, const std::vector<int>& outputs,
     const char* init_data, size_t init_data_size, void* builtin_data,
     const TfLiteRegistration* registration, int* node_index) {
-  invokable_ = false;
+  if (state_ == kStateInvokableAndImmutable) {
+    ReportError(&context_,
+                "AddNodeWithParameters is disallowed when graph is immutable.");
+    return kTfLiteError;
+  }
+  state_ = kStateUninvokable;
 
   std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
                                                               free);
@@ -420,12 +429,17 @@ TfLiteStatus Interpreter::AddNodeWithParameters(
 
 TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                             const std::vector<int>& dims) {
+  if (state_ == kStateInvokableAndImmutable) {
+    ReportError(&context_,
+                "ResizeInputTensor is disallowed when graph is immutable.");
+    return kTfLiteError;
+  }
+  state_ = kStateUninvokable;
+
   // TODO(aselle): All bounds checks can be implemented as one-sided bounds
   // checks by casting to unsigned for efficiency. Profile before doing this.
-
   TF_LITE_ENSURE(&context_,
                  tensor_index < context_.tensors_size && tensor_index >= 0);
-  invokable_ = false;
   TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
   return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
 }
@@ -490,7 +504,7 @@ TfLiteStatus Interpreter::Invoke() {
     ReportError(&context_, "Invoke called on model that is not consistent.");
     return kTfLiteError;
   }
-  if (!invokable_) {
+  if (state_ == kStateUninvokable) {
     ReportError(&context_, "Invoke called on model that is not ready.");
     return kTfLiteError;
   }
@@ -622,6 +636,13 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
     int tensor_index, TfLiteType type, const char* name, const int rank,
     const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
     size_t bytes, const Allocation* allocation) {
+  if (state_ == kStateInvokableAndImmutable) {
+    ReportError(
+        &context_,
+        "SetTensorParametersReadOnly is disallowed when graph is immutable.");
+    return kTfLiteError;
+  }
+
   TF_LITE_ENSURE(&context_,
                  tensor_index < context_.tensors_size && tensor_index >= 0);
   // For most tensors we know exactly how much memory is necessary so we can
@@ -645,7 +666,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
     tensor.allocation_type = kTfLiteMmapRo;
     tensor.allocation = allocation;
   } else {
-    invokable_ = false;
+    state_ = kStateUninvokable;
     TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
                       quantization, const_cast<char*>(buffer), bytes,
                       kTfLiteMmapRo, allocation, &tensor);
@@ -660,7 +681,12 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
 TfLiteStatus Interpreter::SetTensorParametersReadWrite(
     int tensor_index, TfLiteType type, const char* name, const int rank,
     const int* dims, TfLiteQuantizationParams quantization) {
-  invokable_ = false;
+  if (state_ == kStateInvokableAndImmutable) {
+    ReportError(
+        &context_,
+        "SetTensorParametersReadWrite is disallowed when graph is immutable.");
+    return kTfLiteError;
+  }
   TF_LITE_ENSURE(&context_,
                  tensor_index < context_.tensors_size && tensor_index >= 0);
   size_t required_bytes = 0;
@@ -738,19 +764,55 @@ void Interpreter::SetNumThreads(int num_threads) {
   context_.recommended_num_threads = num_threads;
 }
 
-TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
+TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
+                                                  bool allow_dynamic_tensors) {
+  if (!allow_dynamic_tensors) {
+    int last_execution_plan_index_prepared;
+    TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt(
+                                     0, &last_execution_plan_index_prepared));
+
+    bool has_dynamic_tensors = true;
+    // Dynamic tensors exist if not all nodes can be prepared.
+    if (last_execution_plan_index_prepared + 1 == execution_plan_.size()) {
+      // If all the nodes can be prepared, check if the last node has dynamic
+      // tensors.
+      int node_index = execution_plan_[last_execution_plan_index_prepared];
+      TfLiteNode& node = nodes_and_registration_[node_index].first;
+      if (!HasDynamicTensor(context_, node.outputs)) {
+        has_dynamic_tensors = false;
+      }
+    }
+    if (has_dynamic_tensors) {
+      ReportError(&context_, "Attempting to resize a fixed-size tensor.");
+      return kTfLiteError;
+    }
+  }
+
   // TODO(aselle): Consider if it is worth storing pointers to delegates.
-  // Setup additional context interface
+  // Setup additional context interface.
   context_.GetNodeAndRegistration = GetNodeAndRegistration;
   context_.ReplaceSubgraphsWithDelegateKernels =
       ReplaceSubgraphsWithDelegateKernels;
   context_.GetExecutionPlan = GetExecutionPlan;
 
   TfLiteStatus status = delegate->Prepare(&context_, delegate);
+
   // Remove additional context info.
   SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
   SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
   SetForbiddenContextFunction(&context_.GetExecutionPlan);
+
+  TF_LITE_ENSURE_OK(&context_, status);
+
+  if (!allow_dynamic_tensors) {
+    TF_LITE_ENSURE_OK(&context_, AllocateTensors());
+    TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
+                                  state_ == kStateInvokableAndImmutable);
+    // After using a delegate which doesn't support dynamic tensors, make the
+    // entire graph immutable.
+    state_ = kStateInvokableAndImmutable;
+  }
+
   return status;
 }
 
index af14337..788546f 100644 (file)
@@ -272,7 +272,9 @@ class Interpreter {
   // Allow a delegate to look at the graph and modify the graph to handle
   // parts of the graph themselves. After this is called, the graph may
   // contain new nodes that replace 1 more nodes.
-  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
+  // WARNING: This is an experimental API and subject to change.
+  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate,
+                                       bool allow_dynamic_tensors = false);
 
   // Ensure the data in `tensor.data` is readable. In case delegate is used,
   // it might require to copy the data from delegate buffer to raw memory.
@@ -447,6 +449,20 @@ class Interpreter {
     }
   }
 
+  // The state of the Interpreter.
+  enum State {
+    // The interpreter isn't ready to be invoked.
+    // `AllocateTensors` needs to be called to enter an invokable state.
+    kStateUninvokable = 0,
+    // The interpreter is ready to be invoked.
+    kStateInvokable,
+    // The interpreter is ready to be invoked, and the graph can't be
+    // further modified. The interpreter will enter this state when
+    // `ModifyGraphWithDelegate` is called with `allow_dynamic_tensors=false`.
+    kStateInvokableAndImmutable,
+  };
+  State state_ = kStateUninvokable;
+
   // A pure C data structure used to communicate with the pure C plugin
   // interface. To avoid copying tensor metadata, this is also the definitive
   // structure to store tensors.
@@ -462,10 +478,6 @@ class Interpreter {
   // the tensor array.
   bool consistent_ = true;
 
-  // Whether the model is safe to invoke (if any errors occurred this
-  // will be false).
-  bool invokable_ = false;
-
   // Array of indices representing the tensors that are inputs to the
   // interpreter.
   std::vector<int> inputs_;
index 7a029c7..efb29d5 100644 (file)
@@ -17,9 +17,11 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "tensorflow/contrib/lite/error_reporter.h"
 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/schema/schema_generated.h"
 #include "tensorflow/contrib/lite/string_util.h"
 #include "tensorflow/contrib/lite/testing/util.h"
+
 namespace tflite {
 namespace {
 
@@ -439,12 +441,12 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
   // String-in String-out node.
   TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
   reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
-    TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
-    TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+    TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+    TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
     DynamicBuffer buf;
-    StringRef str_ref = GetString(a0, 0);
+    StringRef str_ref = GetString(input, 0);
     buf.AddString(str_ref);
-    buf.WriteToTensor(a1);
+    buf.WriteToTensor(output);
     return kTfLiteOk;
   };
 
@@ -778,13 +780,17 @@ TfLiteRegistration AddOpRegistration() {
 
   reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
     // Set output size to input size
-    TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
-    TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]];
-    TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]];
-    TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
-    TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims);
-    TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size);
-    TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize));
+    TfLiteTensor* input1 = &context->tensors[node->inputs->data[0]];
+    TfLiteTensor* input2 = &context->tensors[node->inputs->data[1]];
+    TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+
+    TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
+    for (int i = 0; i < input1->dims->size; ++i) {
+      TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
+    }
+
+    TF_LITE_ENSURE_STATUS(context->ResizeTensor(
+        context, output, TfLiteIntArrayCopy(input1->dims)));
     return kTfLiteOk;
   };
 
@@ -818,6 +824,8 @@ class TestDelegate : public ::testing::Test {
                                                quant);
     interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
                                                quant);
+    interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
+                                               quant);
     TfLiteRegistration reg = AddOpRegistration();
     interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
     interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
@@ -916,7 +924,6 @@ class TestDelegate : public ::testing::Test {
 };
 
 TEST_F(TestDelegate, BasicDelegate) {
-  interpreter_->Invoke();
   delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
   interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
 
@@ -944,7 +951,6 @@ TEST_F(TestDelegate, BasicDelegate) {
 }
 
 TEST_F(TestDelegate, ComplexDeligate) {
-  interpreter_->Invoke();
   delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2}));
   interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
 
@@ -959,7 +965,6 @@ TEST_F(TestDelegate, ComplexDeligate) {
 }
 
 TEST_F(TestDelegate, SetBufferHandleToInput) {
-  interpreter_->Invoke();
   delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
   TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
   interpreter_->ModifyGraphWithDelegate(delegate);
@@ -978,7 +983,6 @@ TEST_F(TestDelegate, SetBufferHandleToInput) {
 }
 
 TEST_F(TestDelegate, SetBufferHandleToOutput) {
-  interpreter_->Invoke();
   delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
   TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
   interpreter_->ModifyGraphWithDelegate(delegate);
@@ -1002,7 +1006,7 @@ TEST_F(TestDelegate, SetInvalidHandleToTensor) {
   interpreter_->Invoke();
   delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
   TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
-  interpreter_->ModifyGraphWithDelegate(delegate);
+  interpreter_->ModifyGraphWithDelegate(delegate, true);
 
   SimpleDelegate another_simple_delegate({0, 1, 2});
 
@@ -1023,6 +1027,88 @@ TEST_F(TestDelegate, SetInvalidHandleToTensor) {
   EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
 }
 
+TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) {
+  delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+  ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteOk);
+  ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 2}), kTfLiteOk);
+  ASSERT_EQ(
+      interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
+      kTfLiteOk);
+  ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError);
+}
+
+class TestDelegateWithDynamicTensors : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    interpreter_.reset(new Interpreter);
+
+    interpreter_->AddTensors(2);
+    interpreter_->SetInputs({0});
+    interpreter_->SetOutputs({1});
+    TfLiteQuantizationParams quant;
+    interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+                                               quant);
+    interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+                                               quant);
+    TfLiteRegistration reg = DynamicCopyOpRegistration();
+    interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
+
+    delegate_.Prepare = [](TfLiteContext* context,
+                           TfLiteDelegate* delegate) -> TfLiteStatus {
+      // In this test, the delegate replaces all the nodes if this function is
+      // called.
+      TfLiteIntArray* execution_plan;
+      TF_LITE_ENSURE_STATUS(
+          context->GetExecutionPlan(context, &execution_plan));
+      context->ReplaceSubgraphsWithDelegateKernels(
+          context, DelegateRegistration(), execution_plan, delegate);
+      return kTfLiteOk;
+    };
+  }
+
+  static TfLiteRegistration DynamicCopyOpRegistration() {
+    TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+    reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+      TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+      SetTensorToDynamic(output);
+      return kTfLiteOk;
+    };
+
+    reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+      // Not implemented since this isn't required in testing.
+      return kTfLiteOk;
+    };
+    return reg;
+  }
+
+  static TfLiteRegistration DelegateRegistration() {
+    TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+    return reg;
+  }
+
+  std::unique_ptr<Interpreter> interpreter_;
+  TfLiteDelegate delegate_;
+};
+
+TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) {
+  interpreter_->ModifyGraphWithDelegate(&delegate_, false);
+
+  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+  // The interpreter should not call the delegate's `Prepare` when dynamic
+  // tensors exist, so the node ID isn't changed.
+  ASSERT_EQ(interpreter_->execution_plan()[0], 0);
+}
+
+TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) {
+  interpreter_->ModifyGraphWithDelegate(&delegate_, true);
+
+  ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+  // The node should be replaced because dynamic tensors are allowed. Therefore
+  // the node ID in the execution plan is changed from 0 to 1.
+  ASSERT_EQ(interpreter_->execution_plan()[0], 1);
+}
+
 }  // namespace
 }  // namespace tflite