}
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
- invokable_ = true;
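+ // A successful allocation moves an uninvokable interpreter into the
+ // invokable state; a graph already made immutable by a delegate stays
+ // immutable.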
+ if (state_ == kStateUninvokable) {
+ state_ = kStateInvokable;
+ }
+ TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
+ state_ == kStateInvokableAndImmutable);
return kTfLiteOk;
}
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
const TfLiteRegistration* registration, int* node_index) {
- invokable_ = false;
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(&context_,
+ "AddNodeWithParameters is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
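+ // Adding a node invalidates any existing tensor allocation, so the
+ // interpreter must go through AllocateTensors() again before Invoke().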
+ state_ = kStateUninvokable;
std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
free);
TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
const std::vector<int>& dims) {
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(&context_,
+ "ResizeInputTensor is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
+ state_ = kStateUninvokable;
+
// TODO(aselle): All bounds checks can be implemented as one-sided bounds
// checks by casting to unsigned for efficiency. Profile before doing this.
-
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
- invokable_ = false;
TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
}
ReportError(&context_, "Invoke called on model that is not consistent.");
return kTfLiteError;
}
- if (!invokable_) {
+ if (state_ == kStateUninvokable) {
ReportError(&context_, "Invoke called on model that is not ready.");
return kTfLiteError;
}
int tensor_index, TfLiteType type, const char* name, const int rank,
const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
size_t bytes, const Allocation* allocation) {
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(
+ &context_,
+ "SetTensorParametersReadOnly is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
+
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
// For most tensors we know exactly how much memory is necessary so we can
tensor.allocation_type = kTfLiteMmapRo;
tensor.allocation = allocation;
} else {
- invokable_ = false;
+ state_ = kStateUninvokable;
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
kTfLiteMmapRo, allocation, &tensor);
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const int rank,
const int* dims, TfLiteQuantizationParams quantization) {
- invokable_ = false;
+ if (state_ == kStateInvokableAndImmutable) {
+ ReportError(
+ &context_,
+ "SetTensorParametersReadWrite is disallowed when graph is immutable.");
+ return kTfLiteError;
+ }
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
size_t required_bytes = 0;
context_.recommended_num_threads = num_threads;
}
-TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
+TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
+ bool allow_dynamic_tensors) {
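+ // When dynamic tensors are disallowed, prepare every op up front so we can
+ // verify that no tensor in the graph is dynamically sized before handing
+ // the execution plan to the delegate.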
+ if (!allow_dynamic_tensors) {
+ int last_execution_plan_index_prepared;
+ TF_LITE_ENSURE_OK(&context_, PrepareOpsStartingAt(
+ 0, &last_execution_plan_index_prepared));
+
+ bool has_dynamic_tensors = true;
+ // Dynamic tensors exist if not all nodes can be prepared.
+ if (last_execution_plan_index_prepared + 1 == execution_plan_.size()) {
+ // If all the nodes can be prepared, check if the last node has dynamic
+ // tensors.
+ int node_index = execution_plan_[last_execution_plan_index_prepared];
+ TfLiteNode& node = nodes_and_registration_[node_index].first;
+ if (!HasDynamicTensor(context_, node.outputs)) {
+ has_dynamic_tensors = false;
+ }
+ }
+ if (has_dynamic_tensors) {
+ ReportError(&context_, "Attempting to resize a fixed-size tensor.");
+ return kTfLiteError;
+ }
+ }
+
// TODO(aselle): Consider if it is worth storing pointers to delegates.
- // Setup additional context interface
+ // Setup additional context interface.
context_.GetNodeAndRegistration = GetNodeAndRegistration;
context_.ReplaceSubgraphsWithDelegateKernels =
ReplaceSubgraphsWithDelegateKernels;
context_.GetExecutionPlan = GetExecutionPlan;
TfLiteStatus status = delegate->Prepare(&context_, delegate);
+
// Remove additional context info.
SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels);
SetForbiddenContextFunction(&context_.GetExecutionPlan);
+
+ TF_LITE_ENSURE_OK(&context_, status);
+
+ if (!allow_dynamic_tensors) {
+ TF_LITE_ENSURE_OK(&context_, AllocateTensors());
+ TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
+ state_ == kStateInvokableAndImmutable);
+ // After using a delegate which doesn't support dynamic tensors, make the
+ // entire graph immutable.
+ state_ = kStateInvokableAndImmutable;
+ }
+
return status;
}
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/testing/util.h"
+
namespace tflite {
namespace {
// String-in String-out node.
TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr};
reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]];
+ TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
DynamicBuffer buf;
- StringRef str_ref = GetString(a0, 0);
+ StringRef str_ref = GetString(input, 0);
buf.AddString(str_ref);
- buf.WriteToTensor(a1);
+ buf.WriteToTensor(output);
return kTfLiteOk;
};
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
// Set output size to input size
- TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
- TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]];
- TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]];
- TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
- TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims);
- TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size);
- TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize));
+ TfLiteTensor* input1 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* input2 = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+
+ TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size);
+ for (int i = 0; i < input1->dims->size; ++i) {
+ TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]);
+ }
+
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(
+ context, output, TfLiteIntArrayCopy(input1->dims)));
return kTfLiteOk;
};
quant);
interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
quant);
+ interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3},
+ quant);
TfLiteRegistration reg = AddOpRegistration();
interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
};
TEST_F(TestDelegate, BasicDelegate) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
}
TEST_F(TestDelegate, ComplexDeligate) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2}));
interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate());
}
TEST_F(TestDelegate, SetBufferHandleToInput) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
interpreter_->ModifyGraphWithDelegate(delegate);
}
TEST_F(TestDelegate, SetBufferHandleToOutput) {
- interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
interpreter_->ModifyGraphWithDelegate(delegate);
interpreter_->Invoke();
delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate();
- interpreter_->ModifyGraphWithDelegate(delegate);
+ interpreter_->ModifyGraphWithDelegate(delegate, true);
SimpleDelegate another_simple_delegate({0, 1, 2});
EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle);
}
+TEST_F(TestDelegate, ResizeInputWithNonDynamicDelegateShouldFail) {
+ delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2}));
+ ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteOk);
+ ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 2}), kTfLiteOk);
+ ASSERT_EQ(
+ interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 2}), kTfLiteError);
+}
+
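+// Fixture with a single op whose output is dynamically sized, used to
+// exercise the allow_dynamic_tensors flag of ModifyGraphWithDelegate.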
+class TestDelegateWithDynamicTensors : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ interpreter_.reset(new Interpreter);
+
+ interpreter_->AddTensors(2);
+ interpreter_->SetInputs({0});
+ interpreter_->SetOutputs({1});
+ TfLiteQuantizationParams quant;
+ interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quant);
+ TfLiteRegistration reg = DynamicCopyOpRegistration();
+ interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
+
+ delegate_.Prepare = [](TfLiteContext* context,
+ TfLiteDelegate* delegate) -> TfLiteStatus {
+ // In this test, the delegate replaces all the nodes if this function is
+ // called.
+ TfLiteIntArray* execution_plan;
+ TF_LITE_ENSURE_STATUS(
+ context->GetExecutionPlan(context, &execution_plan));
+ context->ReplaceSubgraphsWithDelegateKernels(
+ context, DelegateRegistration(), execution_plan, delegate);
+ return kTfLiteOk;
+ };
+ }
+
+ static TfLiteRegistration DynamicCopyOpRegistration() {
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ };
+
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ // Not implemented since this isn't required in testing.
+ return kTfLiteOk;
+ };
+ return reg;
+ }
+
+ static TfLiteRegistration DelegateRegistration() {
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ return reg;
+ }
+
+ std::unique_ptr<Interpreter> interpreter_;
+ TfLiteDelegate delegate_;
+};
+
+TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) {
+ interpreter_->ModifyGraphWithDelegate(&delegate_, false);
+
+ ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+ // The interpreter should not call the delegate's `Prepare` when dynamic
+ // tensors exist, so the node ID is unchanged.
+ ASSERT_EQ(interpreter_->execution_plan()[0], 0);
+}
+
+TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) {
+ interpreter_->ModifyGraphWithDelegate(&delegate_, true);
+
+ ASSERT_EQ(interpreter_->execution_plan().size(), 1);
+ // The node should be replaced because dynamic tensors are allowed. Therefore
+ // the only node ID in the execution plan changes from 0 to 1.
+ ASSERT_EQ(interpreter_->execution_plan()[0], 1);
+}
+
} // namespace
} // namespace tflite