Return kTfLiteError if calling delegate-specific functions from non-delegate code.
authorYu-Cheng Ling <ycling@google.com>
Fri, 9 Mar 2018 00:16:47 +0000 (16:16 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Mar 2018 00:21:23 +0000 (16:21 -0800)
PiperOrigin-RevId: 188407931

tensorflow/contrib/lite/interpreter.cc
tensorflow/contrib/lite/interpreter_test.cc

index 2834dc4..4710488 100644 (file)
@@ -30,6 +30,27 @@ limitations under the License.
 
 namespace tflite {
 
+namespace {
+
+// Stub method which returns kTfLiteError when the function is forbidden.
+// We're registering this function for several different functions to save
+// compiled binary size. Please note the restrictions:
+// * The type of the first parameter has to be `TfLiteContext*`.
+// * All parameters must be trivially destructible. (E.g. No C++ class)
+TfLiteStatus ForbiddenContextFunction(TfLiteContext* context, ...) {
+  context->ReportError(context,
+                       "The function is forbidden if not calling in delegate.");
+  return kTfLiteError;
+}
+
+// Set the ForbiddenContextFunction to a compatible function pointer.
+template <typename FunctionType>
+void SetForbiddenContextFunction(FunctionType* func) {
+  *func = reinterpret_cast<FunctionType>(ForbiddenContextFunction);
+}
+
+}  // namespace
+
 // A trivial implementation of GraphInfo around the Interpreter.
 // NOTE: this interpreter info represents the subset of the
 // graph that is executed according to execution plan. Thus,
@@ -74,9 +95,9 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
   context_.gemm_context = nullptr;
 
   // Invalid to call these these except from TfLiteDelegate
-  context_.GetNodeAndRegistration = nullptr;
-  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
-  context_.GetExecutionPlan = nullptr;
+  SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
+  SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
+  SetForbiddenContextFunction(&context_.GetExecutionPlan);
 
   // Reserve some space for the tensors to avoid excessive resizing.
   tensors_.reserve(kTensorsReservedCapacity);
@@ -686,9 +707,9 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
 
   TfLiteStatus status = delegate->Prepare(&context_, delegate);
   // Remove additional context info.
-  context_.GetNodeAndRegistration = nullptr;
-  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
-  context_.GetExecutionPlan = nullptr;
+  SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
+  SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
+  SetForbiddenContextFunction(&context_.GetExecutionPlan);
   return status;
 }
 
index 2586c15..17eb2f4 100644 (file)
@@ -561,6 +561,46 @@ TEST(BasicInterpreter, TestCustomErrorReporter) {
   ASSERT_EQ(reporter.calls, 1);
 }
 
+TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) {
+  Interpreter interpreter;
+  ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+  TfLiteRegistration registration = {
+      .init = nullptr, .free = nullptr, .prepare = nullptr, .invoke = nullptr};
+  // These functions are only supported inside Delegate's Prepare function.
+  // The test verifies that these functions return `kTfLiteError` rather
+  // than `kTfLiteOk`, and that they do not crash.
+  registration.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+    {
+      TfLiteIntArray* execution_plan;
+      EXPECT_EQ(context->GetExecutionPlan(context, &execution_plan),
+                kTfLiteError);
+    }
+    {
+      TfLiteNode* node;
+      TfLiteRegistration* registration;
+      EXPECT_EQ(
+          context->GetNodeAndRegistration(context, 0, &node, &registration),
+          kTfLiteError);
+    }
+    {
+      TfLiteRegistration delegate_registration = {nullptr, nullptr, nullptr,
+                                                  nullptr};
+      TfLiteIntArray nodes_to_replace;
+      nodes_to_replace.size = 0;
+      EXPECT_EQ(context->ReplaceSubgraphsWithDelegateKernels(
+                    context, delegate_registration, &nodes_to_replace, nullptr),
+                kTfLiteError);
+    }
+    return kTfLiteError;
+  };
+  ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+  ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
+  ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr,
+                                              &registration),
+            kTfLiteOk);
+  EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError);
+}
+
 TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) {
   Interpreter interpreter;
   ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity),