deps = [
":framework",
":string_util",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+// Forward declare so TfLiteContext's GetNodeAndRegistration can use this type.
+typedef struct _TfLiteRegistration TfLiteRegistration;
+
#define kOptionalTensor (-1)
// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
// indices.
// Resize the allocated data of a (dynamic) tensor.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+ // Inputs to this node expressed as indices into the interpreter's tensors.
+ TfLiteIntArray* inputs;
+
+ // Outputs from this node expressed as indices into the interpreter's tensors.
+ TfLiteIntArray* outputs;
+
+ // Temporary tensors used during computation. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin. This is usually
+ // a structure defined in builtin_op_data.h
+ void* builtin_data;
+
+ // Custom initial data. This is the opaque data provided in the flatbuffer.
+ // WARNING: This is an experimental interface that is subject to change.
+ const void* custom_initial_data;
+ int custom_initial_data_size;
+} TfLiteNode;
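+// For illustration only, a kernel typically resolves these indices against
+// the context's tensor array (assuming `context` is the TfLiteContext handed
+// to the kernel alongside `node`):
+//   TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+//   TfLiteTensor* output = &context->tensors[node->outputs->data[0]];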
+
typedef struct TfLiteContext {
// Number of tensors in the context.
int tensors_size;
+
+ // The execution plan contains a list of the node indices in execution
+ // order. execution_plan->size is the current number of nodes. And,
+ // execution_plan->data[0] is the first node that needs to be run.
+ // TfLiteDelegates can traverse the current execution plan by iterating
+ // through each member of this array and using GetNodeAndRegistration() to
+ // access details about a node, e.g.:
+ // TfLiteIntArray* execution_plan;
+ // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
+ // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
+ // int node_index = execution_plan->data[exec_index];
+ // TfLiteNode* node;
+ // TfLiteRegistration* reg;
+ // context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // }
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan);
+
// An array of tensors in the interpreter context (of length `tensors_size`)
TfLiteTensor* tensors;
TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
int* first_new_tensor_index);
+ // Get a node and its registration by node_index.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
+ TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // Replace the ops listed in `nodes_to_replace` with delegate kernels using `registration`.
+ TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
+ struct TfLiteContext*, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace);
+
// TODO(ahentz): we should create a more general mechanism for this sort of
// library-global objects.
void* gemm_context;
} TfLiteContext;
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct {
- // Inputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* inputs;
-
- // Outputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* outputs;
-
- // Temporary tensors uses during the computations. This usually contains no
- // tensors, but ops are allowed to change that if they need scratch space of
- // any sort.
- TfLiteIntArray* temporaries;
-
- // Opaque data provided by the node implementer through `Registration.init`.
- void* user_data;
-
- // Opaque data provided to the node if the node is a builtin.
- void* builtin_data;
-} TfLiteNode;
-
-typedef struct {
+typedef struct _TfLiteRegistration {
// Initializes the op from serialized data.
// If a built-in op:
// `buffer` is the op's params data (TfLiteLSTMParams*).
// NN API. Note, it is the responsibility of the registration binder to
// set this properly.
int32_t builtin_code;
+
+ // Custom op name. If the op is a builtin, this will be null.
+ // WARNING: This is an experimental interface that is subject to change.
+ const char* custom_name;
} TfLiteRegistration;
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct {
+ // Data that the delegate needs to identify itself. This data is owned by the
+ // delegate. The delegate is owned in the user code, so the delegate is
+ // responsible for deallocating this data when it is destroyed.
+ void* data_;
+ // Invoked by ModifyGraphWithDelegate. This gives the delegate a view of the
+ // current graph through TfLiteContext*. It typically will look at the nodes
+ // and call ReplaceSubgraphsWithDelegateKernels() to ask the TensorFlow Lite
+ // runtime to create macro-nodes to represent delegated subgraphs of the
+ // original graph.
+ TfLiteStatus (*Prepare)(TfLiteContext* context, void* data);
+} TfLiteDelegate;
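+// For illustration only, a minimal delegate might be assembled as follows.
+// This is a sketch; `MyDelegateKernelRegistration` is a hypothetical helper
+// returning the TfLiteRegistration of the fused kernel.
+//   TfLiteStatus MyPrepare(TfLiteContext* context, void* data) {
+//     TfLiteIntArray* plan;
+//     TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
+//     // Claim every node in the current execution plan for this delegate.
+//     return context->ReplaceSubgraphsWithDelegateKernels(
+//         context, MyDelegateKernelRegistration(), plan);
+//   }
+//   TfLiteDelegate delegate = {/*data_=*/nullptr, /*Prepare=*/MyPrepare};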
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
context_.tensors = nullptr;
context_.tensors_size = 0;
context_.gemm_context = nullptr;
+
+ // Invalid to call these except from within a TfLiteDelegate's Prepare.
+ context_.GetNodeAndRegistration = nullptr;
+ context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
+ context_.GetExecutionPlan = nullptr;
+
// Reserve some space for the tensors to avoid excessive resizing.
tensors_.reserve(kSlotsToReserve);
nodes_and_registration_.reserve(kSlotsToReserve);
}
}
+TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
+ TfLiteContext* context, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace) {
+ return static_cast<Interpreter*>(context->impl_)
+ ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace);
+}
+
+TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
+ TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) {
+ // Analyze the graph to find all independent subgraphs that are either
+ // fully not-this-delegate or this-delegate computation.
+ InterpreterInfo info(this);
+ std::vector<Subgraph> subgraphs;
+ PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs);
+
+ execution_plan_.clear();
+ for (auto& subgraph : subgraphs) {
+ // Turn subgraph.nodes into a TfLiteIntArray compatible data structure.
+ // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way
+ // in the first place
+ subgraph.nodes.insert(subgraph.nodes.begin(),
+ static_cast<int>(subgraph.nodes.size()));
+ // Subgraphs claimed by the delegate should have a "macro" op created; the
+ // other subgraphs (kTfNonPartition) just have their nodes added back to
+ // the execution plan.
+ switch (subgraph.type) {
+ case Subgraph::kTfNonPartition:
+ for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end();
+ ++it) {
+ execution_plan_.push_back(*it);
+ }
+ break;
+ case Subgraph::kTfPartition: {
+ void* builtin_data = nullptr;
+ int node_index;
+ // Create a node that represents computation of this subgraph.
+ AddNodeWithParameters(
+ subgraph.input_tensors, subgraph.output_tensors,
+ reinterpret_cast<const char*>(subgraph.nodes.data()),
+ subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data,
+ &registration, &node_index);
+ } break;
+ case Subgraph::kTfUnexplored:
+ return kTfLiteError;
+ break;
+ }
+ }
+ return kTfLiteOk;
+}
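+// For illustration only, the delegate kernel named by `registration` can
+// recover the serialized node list in its `init` callback; the buffer layout
+// produced above is [node_count, node_0, node_1, ...] as ints (a sketch,
+// with `buffer` being the custom initial data handed to `init`):
+//   const int* nodes = reinterpret_cast<const int*>(buffer);
+//   int node_count = nodes[0];
+//   const int* node_indices = &nodes[1];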
+
+// Gets a TfLiteIntArray* representing the execution plan. The interpreter
+// owns this memory and it is only guaranteed to exist during the invocation
+// of the delegate's Prepare.
+TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) {
+ // TODO(aselle): Do not make a copy here
+ plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size()));
+ *execution_plan = plan_cache_.get();
+ static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
+ "TfLiteIntArray and execution_plan do not contain same type.");
+ memcpy(plan_cache_->data, execution_plan_.data(),
+ sizeof(plan_cache_->data[0]) * execution_plan_.size());
+ return kTfLiteOk;
+}
+
+// WARNING: This is an experimental interface that is subject to change.
+// Entry point for C node plugin API to get the execution plan
+TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan) {
+ return static_cast<Interpreter*>(context->impl_)
+ ->GetExecutionPlan(execution_plan);
+}
+
TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
TF_LITE_ENSURE_OK(&context_,
CheckTensorIndices("inputs", inputs.data(), inputs.size()));
int new_node_index = nodes_and_registration_.size();
if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
+
auto& node_and_reg = nodes_and_registration_.back();
TfLiteNode& node = node_and_reg.first;
if (node.inputs) TfLiteIntArrayFree(node.inputs);
->AddTensors(tensors_to_add, first_new_tensor_index);
}
+TfLiteStatus Interpreter::GetNodeAndRegistration(
+ int node_index, TfLiteNode** node, TfLiteRegistration** registration) {
+ TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0);
+ TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr);
+ *node = &nodes_and_registration_[node_index].first;
+ *registration = &nodes_and_registration_[node_index].second;
+ return kTfLiteOk;
+}
+
+TfLiteStatus Interpreter::GetNodeAndRegistration(
+ struct TfLiteContext* context, int node_index, TfLiteNode** node,
+ TfLiteRegistration** registration) {
+ return static_cast<Interpreter*>(context->impl_)
+ ->GetNodeAndRegistration(node_index, node, registration);
+}
+
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
int tensor_index, TfLiteType type, const char* name,
const std::vector<int>& dims, TfLiteQuantizationParams quantization,
tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
}
+TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
+ // TODO(aselle): Consider if it is worth storing pointers to delegates.
+ // Setup additional context interface
+ context_.GetNodeAndRegistration = GetNodeAndRegistration;
+ context_.ReplaceSubgraphsWithDelegateKernels =
+ ReplaceSubgraphsWithDelegateKernels;
+ context_.GetExecutionPlan = GetExecutionPlan;
+
+ TfLiteStatus status = delegate->Prepare(&context_, delegate->data_);
+ // Remove additional context info.
+ context_.GetNodeAndRegistration = nullptr;
+ context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
+ context_.GetExecutionPlan = nullptr;
+ return status;
+}
+
} // namespace tflite
// foo.Invoke();
//
+struct TfLiteIntArrayDeleter {
+ void operator()(TfLiteIntArray* a) {
+ if (a) TfLiteIntArrayFree(a);
+ }
+};
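+// For illustration only, this deleter is intended for smart pointers holding
+// a TfLiteIntArray, e.g.
+//   std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> arr(
+//       TfLiteIntArrayCreate(4));
+// so the array is released with TfLiteIntArrayFree() when the pointer resets.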
+
class Interpreter {
public:
// Instantiate an interpreter. All errors associated with reading and
// Set the number of threads available to the interpreter.
void SetNumThreads(int num_threads);
+ // Allow a delegate to look at the graph and modify it to handle parts of
+ // the graph itself. After this is called, the graph may contain new nodes
+ // that replace one or more of the original nodes.
+ TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
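+ // For illustration only (a sketch; `my_delegate` is a hypothetical
+ // TfLiteDelegate* prepared as in the delegate docs above):
+ //   if (interpreter.ModifyGraphWithDelegate(my_delegate) != kTfLiteOk) {
+ //     // The delegate could not be applied; handle the error.
+ //   }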
+
private:
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add,
int* first_new_tensor_index);
+ // WARNING: This is an experimental API and subject to change.
+ // Entry point for C API ReplaceSubgraphsWithDelegateKernels
+ static TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
+ TfLiteContext* context, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace);
+
+ // Update the execution graph to replace some of the nodes with stub
+ // nodes. Specifically, any node whose index appears in `nodes_to_replace`
+ // will be slated for replacement with a delegate kernel specified by
+ // registration.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus ReplaceSubgraphsWithDelegateKernels(
+ TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace);
+
+ // WARNING: This is an experimental interface that is subject to change.
+ // Gets the internal pointer to a TensorFlow Lite node by node_index.
+ TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // WARNING: This is an experimental interface that is subject to change.
+ // Entry point for C node plugin API to get a node by index.
+ static TfLiteStatus GetNodeAndRegistration(struct TfLiteContext*,
+ int node_index, TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // WARNING: This is an experimental interface that is subject to change.
+ // Gets a TfLiteIntArray* representing the execution plan. The interpreter
+ // owns this memory and it is only guaranteed to exist during the invocation
+ // of the delegate's Prepare.
+ TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan);
+
+ // WARNING: This is an experimental interface that is subject to change.
+ // Entry point for C node plugin API to get the execution plan
+ static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan);
+
// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.
// subset of the node indices.
std::vector<int> execution_plan_;
+ // In the future, we'd like a TfLiteIntArray compatible representation.
+ // TODO(aselle): replace execution_plan_ with this.
+ std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> plan_cache_;
+
// Whether to delegate to NN API
std::unique_ptr<NNAPIDelegate> nnapi_delegate_;
#include "tensorflow/contrib/lite/interpreter.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/testing/util.h"
-
namespace tflite {
namespace {
ASSERT_EQ(run_order_, std::vector<int>());
}
+// Build a kernel registration for an op that adds its two input tensors
+// element-wise and writes the result to its output.
+TfLiteRegistration AddOpRegistration() {
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+
+ reg.custom_name = "my_add";
+ reg.builtin_code = tflite::BuiltinOperator_CUSTOM;
+
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ // Set output size to input size
+ TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]];
+ TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims);
+ TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims);
+ TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size);
+ // newSizeOther was only needed for the size check; free it here.
+ TfLiteIntArrayFree(newSizeOther);
+ TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize));
+ return kTfLiteOk;
+ };
+
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ // Copy input data to output data.
+ TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]];
+ TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]];
+ TfLiteTensor* out = &context->tensors[node->outputs->data[0]];
+ int num = a0->dims->data[0];
+ for (int i = 0; i < num; i++) {
+ out->data.f[i] = a0->data.f[i] + a1->data.f[i];
+ }
+ return kTfLiteOk;
+ };
+ return reg;
+}
+
+class TestDelegate : public ::testing::Test {
+ public:
+ TestDelegate() {
+ interpreter_.AddTensors(5);
+ interpreter_.SetInputs({0, 1});
+ interpreter_.SetOutputs({3, 4});
+ TfLiteQuantizationParams quant;
+ interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3},
+ quant);
+ interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3},
+ quant);
+ TfLiteRegistration reg = AddOpRegistration();
+ interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, &reg);
+ interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, &reg);
+ interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, &reg);
+ }
+
+ protected:
+ class SimpleDelegate {
+ public:
+ // Create a simple implementation of a TfLiteDelegate. The C++ class
+ // SimpleDelegate produces a plain TfLiteDelegate handle that is
+ // value-copyable and compatible with TfLite.
+ explicit SimpleDelegate(const std::vector<int>& nodes) : nodes_(nodes) {
+ delegate_.Prepare = [](TfLiteContext* context,
+ void* data) -> TfLiteStatus {
+ auto* simple = reinterpret_cast<SimpleDelegate*>(data);
+ TfLiteIntArray* nodes_to_separate =
+ TfLiteIntArrayCreate(simple->nodes_.size());
+ // Mark the nodes that we want in the TfLiteIntArray* structure.
+ int index = 0;
+ for (auto node_index : simple->nodes_) {
+ nodes_to_separate->data[index++] = node_index;
+ // Make sure the node is the custom "my_add" op.
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
+ TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
+ }
+ // Check that all nodes are available
+ TfLiteIntArray* execution_plan;
+ TF_LITE_ENSURE_STATUS(
+ context->GetExecutionPlan(context, &execution_plan));
+ for (int exec_index = 0; exec_index < execution_plan->size;
+ exec_index++) {
+ int node_index = execution_plan->data[exec_index];
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM);
+ TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0);
+ }
+
+ context->ReplaceSubgraphsWithDelegateKernels(
+ context, FakeFusedRegistration(), nodes_to_separate);
+ TfLiteIntArrayFree(nodes_to_separate);
+ return kTfLiteOk;
+ };
+ // Store a type-punned pointer to this SimpleDelegate structure.
+ delegate_.data_ = reinterpret_cast<void*>(this);
+ }
+
+ static TfLiteRegistration FakeFusedRegistration() {
+ TfLiteRegistration reg = {nullptr};
+ reg.custom_name = "fake_fused_op";
+ return reg;
+ }
+
+ TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; }
+
+ private:
+ std::vector<int> nodes_;
+ TfLiteDelegate delegate_;
+ };
+ Interpreter interpreter_;
+};
+
+TEST_F(TestDelegate, BasicDelegate) {
+ interpreter_.Invoke();
+ SimpleDelegate simple({0, 1, 2});
+ interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+
+ ASSERT_EQ(interpreter_.execution_plan().size(), 1);
+ int node = interpreter_.execution_plan()[0];
+ const auto* node_and_reg = interpreter_.node_and_registration(node);
+ ASSERT_STREQ(node_and_reg->second.custom_name,
+ SimpleDelegate::FakeFusedRegistration().custom_name);
+}
+
+TEST_F(TestDelegate, ComplexDelegate) {
+ interpreter_.Invoke();
+ SimpleDelegate simple({1, 2});
+ interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate());
+
+ ASSERT_EQ(interpreter_.execution_plan().size(), 2);
+ // 0th should be a non-delegated original op
+ ASSERT_EQ(interpreter_.execution_plan()[0], 0);
+ // 1st should be a new macro op (node 3) which didn't exist before.
+ ASSERT_EQ(interpreter_.execution_plan()[1], 3);
+ const auto* node_and_reg = interpreter_.node_and_registration(3);
+ ASSERT_STREQ(node_and_reg->second.custom_name,
+ SimpleDelegate::FakeFusedRegistration().custom_name);
+}
+
} // namespace
} // namespace tflite