[Runtime] EdgeTPU runtime for Coral Boards (#4698)
authorThierry Moreau <tmoreau@octoml.ai>
Thu, 16 Jan 2020 20:20:42 +0000 (12:20 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 16 Jan 2020 20:20:42 +0000 (12:20 -0800)
cmake/config.cmake
cmake/modules/contrib/TFLite.cmake
python/tvm/contrib/tflite_runtime.py
src/runtime/contrib/edgetpu/edgetpu_runtime.cc [new file with mode: 0644]
src/runtime/contrib/edgetpu/edgetpu_runtime.h [new file with mode: 0644]
src/runtime/contrib/tflite/tflite_runtime.cc
src/runtime/contrib/tflite/tflite_runtime.h
tests/python/contrib/test_edgetpu_runtime.py [new file with mode: 0644]
tests/python/contrib/test_tflite_runtime.py

index 7929cd2..40549ec 100644 (file)
@@ -154,6 +154,11 @@ set(USE_TFLITE OFF)
 # /path/to/tensorflow: tensorflow root path when use tflite library
 set(USE_TENSORFLOW_PATH none)
 
+# Possible values:
+# - OFF: disable tflite support for edgetpu
+# - /path/to/edgetpu: use specific path to edgetpu library
+set(USE_EDGETPU OFF)
+
 # Whether use CuDNN
 set(USE_CUDNN OFF)
 
index 9074def..ec03c96 100644 (file)
@@ -25,6 +25,15 @@ if(NOT USE_TFLITE STREQUAL "OFF")
   list(APPEND RUNTIME_SRCS ${TFLITE_CONTRIB_SRC})
   include_directories(${USE_TENSORFLOW_PATH})
 
+  # Additional EdgeTPU libs
+  if (NOT USE_EDGETPU STREQUAL "OFF")
+    message(STATUS "Build with contrib.edgetpu")
+    file(GLOB EDGETPU_CONTRIB_SRC src/runtime/contrib/edgetpu/*.cc)
+    list(APPEND RUNTIME_SRCS ${EDGETPU_CONTRIB_SRC})
+    include_directories(${USE_EDGETPU}/libedgetpu)
+    list(APPEND TVM_RUNTIME_LINKER_LIBS ${USE_EDGETPU}/libedgetpu/direct/aarch64/libedgetpu.so.1)
+  endif()
+
   if (USE_TFLITE STREQUAL "ON")
     set(USE_TFLITE ${USE_TENSORFLOW_PATH}/tensorflow/lite/tools/make/gen/*/lib)
   endif()
index 89a547f..5ff30a1 100644 (file)
@@ -18,7 +18,7 @@
 from .._ffi.function import get_global_func
 from ..rpc import base as rpc_base
 
-def create(tflite_model_bytes, ctx):
+def create(tflite_model_bytes, ctx, runtime_target='cpu'):
     """Create a runtime executor module given a tflite model and context.
     Parameters
     ----------
@@ -27,16 +27,25 @@ def create(tflite_model_bytes, ctx):
     ctx : TVMContext
         The context to deploy the module. It can be local or remote when there
         is only one TVMContext.
+    runtime_target: str
+        Execution target of TFLite runtime: either `cpu` or `edge_tpu`.
     Returns
     -------
     tflite_runtime : TFLiteModule
         Runtime tflite module that can be used to execute the tflite model.
     """
     device_type = ctx.device_type
+
+    if runtime_target == 'edge_tpu':
+        runtime_func = "tvm.edgetpu_runtime.create"
+    else:
+        runtime_func = "tvm.tflite_runtime.create"
+
     if device_type >= rpc_base.RPC_SESS_MASK:
-        fcreate = ctx._rpc_sess.get_function("tvm.tflite_runtime.create")
-        return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
-    fcreate = get_global_func("tvm.tflite_runtime.create")
+        fcreate = ctx._rpc_sess.get_function(runtime_func)
+    else:
+        fcreate = get_global_func(runtime_func)
+
     return TFLiteModule(fcreate(bytearray(tflite_model_bytes), ctx))
 
 
@@ -50,12 +59,12 @@ class TFLiteModule(object):
     Parameters
     ----------
     module : Module
-        The interal tvm module that holds the actual tflite functions.
+        The internal tvm module that holds the actual tflite functions.
 
     Attributes
     ----------
     module : Module
-        The interal tvm module that holds the actual tflite functions.
+        The internal tvm module that holds the actual tflite functions.
     """
 
     def __init__(self, module):
@@ -63,7 +72,6 @@ class TFLiteModule(object):
         self._set_input = module["set_input"]
         self._invoke = module["invoke"]
         self._get_output = module["get_output"]
-        self._allocate_tensors = module["allocate_tensors"]
 
     def set_input(self, index, value):
         """Set inputs to the module via kwargs
@@ -91,12 +99,6 @@ class TFLiteModule(object):
         """
         self._invoke()
 
-    def allocate_tensors(self):
-        """Allocate space for all tensors.
-        """
-        self._allocate_tensors()
-
-
     def get_output(self, index):
         """Get index-th output to out
 
diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc
new file mode 100644 (file)
index 0000000..4823ef7
--- /dev/null
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file edgetpu_runtime.cc
+ */
+#include <tvm/runtime/registry.h>
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+#include <edgetpu.h>
+
+
+#include "edgetpu_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+void EdgeTPURuntime::Init(const std::string& tflite_model_bytes,
+                          TVMContext ctx) {
+  const char* buffer = tflite_model_bytes.c_str();
+  size_t buffer_size = tflite_model_bytes.size();
+  // Load compiled model as a FlatBufferModel
+  std::unique_ptr<tflite::FlatBufferModel> model =
+    tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+  // Build resolver
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  // Init EdgeTPUContext object
+  edgetpu_context_ = edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice();
+  // Add custom edgetpu ops to resolver
+  resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
+  // Build interpreter
+  TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
+  CHECK_TFLITE_STATUS(status) << "Failed to build interpreter.";
+  // Bind EdgeTPU context with interpreter.
+  interpreter_->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context_.get());
+  interpreter_->SetNumThreads(1);
+  // Allocate tensors
+  status = interpreter_->AllocateTensors();
+  CHECK_TFLITE_STATUS(status) << "Failed to allocate tensors.";
+
+  ctx_ = ctx;
+}
+
+Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes,
+                           TVMContext ctx) {
+  auto exec = make_object<EdgeTPURuntime>();
+  exec->Init(tflite_model_bytes, ctx);
+  return Module(exec);
+}
+
+TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create")
+  .set_body([](TVMArgs args, TVMRetValue* rv) {
+    *rv = EdgeTPURuntimeCreate(args[0], args[1]);
+  });
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.h b/src/runtime/contrib/edgetpu/edgetpu_runtime.h
new file mode 100644 (file)
index 0000000..78730d5
--- /dev/null
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief EdgeTPU runtime that can run tflite model compiled
+ *        for EdgeTPU containing only tvm PackedFunc.
+ * \file edgetpu_runtime.h
+ */
+#ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
+#define TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
+
+#include <string>
+#include <memory>
+
+#include "../tflite/tflite_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief EdgeTPU runtime.
+ *
+ *  This runtime can be accessed in various languages via
+ *  the TVM runtime PackedFunc API.
+ */
+class EdgeTPURuntime : public TFLiteRuntime {
+ public:
+  /*!
+   * \return The type key of the executor.
+   */
+  const char* type_key() const final {
+    return "EdgeTPURuntime";
+  }
+
+  /*!
+   * \brief Initialize the edge TPU tflite runtime with tflite model and context.
+   * \param tflite_model_bytes The tflite model.
+   * \param ctx The context where the tflite model will be executed on.
+   */
+  void Init(const std::string& tflite_model_bytes,
+            TVMContext ctx);
+
+ private:
+  std::shared_ptr<edgetpu::EdgeTpuContext> edgetpu_context_;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_
index e249f35..56d3ce9 100644 (file)
@@ -21,7 +21,6 @@
  * \file tflite_runtime.cc
  */
 #include <tvm/runtime/registry.h>
-#include <tvm/dtype.h>
 #include <tensorflow/lite/interpreter.h>
 #include <tensorflow/lite/kernels/register.h>
 #include <tensorflow/lite/model.h>
@@ -33,37 +32,37 @@ namespace tvm {
 namespace runtime {
 
 #define TVM_DTYPE_DISPATCH(type, DType, ...)            \
-  if (type == DataType::Float(64)) {                              \
+  if (type == DataType::Float(64)) {                    \
     typedef double DType;                               \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::Float(32)) {                       \
+  } else if (type == DataType::Float(32)) {             \
     typedef float DType;                                \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::Float(16)) {                       \
+  } else if (type == DataType::Float(16)) {             \
     typedef uint16_t DType;                             \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::Int(64)) {                         \
+  } else if (type == DataType::Int(64)) {               \
     typedef int64_t DType;                              \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::Int(32)) {                         \
+  } else if (type == DataType::Int(32)) {               \
     typedef int32_t DType;                              \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::Int(16)) {                         \
+  } else if (type == DataType::Int(16)) {               \
     typedef int16_t DType;                              \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::Int(8)) {                          \
+  } else if (type == DataType::Int(8)) {                \
     typedef int8_t DType;                               \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::UInt(64)) {                        \
+  } else if (type == DataType::UInt(64)) {              \
     typedef uint64_t DType;                             \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::UInt(32)) {                        \
+  } else if (type == DataType::UInt(32)) {              \
     typedef uint32_t DType;                             \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::UInt(16)) {                        \
+  } else if (type == DataType::UInt(16)) {              \
     typedef uint16_t DType;                             \
     {__VA_ARGS__}                                       \
-  } else if (type == DataType::UInt(8)) {                         \
+  } else if (type == DataType::UInt(8)) {               \
     typedef uint8_t DType;                              \
     {__VA_ARGS__}                                       \
   } else {                                              \
@@ -79,9 +78,9 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
     case kTfLiteInt64:
       return DataType::Int(64);
     case kTfLiteInt16:
-      returnDataType::Int(16);
+      return DataType::Int(16);
     case kTfLiteInt8:
-      returnDataType::Int(8);
+      return DataType::Int(8);
     case kTfLiteUInt8:
       return DataType::UInt(8);
     case kTfLiteFloat16:
@@ -92,7 +91,6 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
   }
 }
 
-
 void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
                          TVMContext ctx) {
   const char* buffer = tflite_model_bytes.c_str();
@@ -100,12 +98,14 @@ void TFLiteRuntime::Init(const std::string& tflite_model_bytes,
   std::unique_ptr<tflite::FlatBufferModel> model =
     tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
   tflite::ops::builtin::BuiltinOpResolver resolver;
-  tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
-  ctx_ = ctx;
-}
+  // Build interpreter
+  TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
+  CHECK_TFLITE_STATUS(status) << "Failed to build interpreter.";
+  // Allocate tensors
+  status = interpreter_->AllocateTensors();
+  CHECK_TFLITE_STATUS(status) << "Failed to allocate tensors.";
 
-void TFLiteRuntime::AllocateTensors() {
-  interpreter_->AllocateTensors();
+  ctx_ = ctx;
 }
 
 void TFLiteRuntime::Invoke() {
@@ -129,7 +129,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
 }
 
 NDArray TFLiteRuntime::GetOutput(int index) const {
-  TfLiteTensor* output = interpreter_->output_tensor(index);
+  TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[index]);
   DataType dtype = TfLiteDType2TVMDType(output->type);
   TfLiteIntArray* dims = output->dims;
   int64_t size = 1;
@@ -167,10 +167,6 @@ PackedFunc TFLiteRuntime::GetFunction(
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         this->Invoke();
       });
-  } else if (name == "allocate_tensors") {
-    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        this->AllocateTensors();
-      });
   } else {
     return PackedFunc();
   }
index 4b08b97..d823690 100644 (file)
 namespace tvm {
 namespace runtime {
 
+#define CHECK_TFLITE_STATUS(ret) CHECK_EQ(ret, kTfLiteOk)
 
 /*!
  * \brief Tflite runtime.
  *
- *  This runtime can be acccesibly in various language via
+ *  This runtime can be accessed in various language via
  *  TVM runtime PackedFunc API.
  */
 class TFLiteRuntime : public ModuleNode {
  public:
   /*!
-   * \brief Get member function to front-end
+   * \brief Get member function to front-end.
    * \param name The name of the function.
    * \param sptr_to_self The pointer to the module node.
    * \return The corresponding member function.
@@ -57,15 +58,11 @@ class TFLiteRuntime : public ModuleNode {
   /*!
    * \return The type key of the executor.
    */
-  const char* type_key() const final {
+  const char* type_key() const {
     return "TFLiteRuntime";
   }
 
   /*!
-   * \brief Update allocations for all tenssors. This is relatively expensive.
-   */
-  void AllocateTensors();
-  /*!
    * \brief Invoke the internal tflite interpreter and run the whole model in 
    * dependency order.
    */
@@ -100,8 +97,9 @@ class TFLiteRuntime : public ModuleNode {
    */
   NDArray GetOutput(int index) const;
 
- private:
+  // TFLite interpreter
   std::unique_ptr<tflite::Interpreter> interpreter_;
+  // TVM context
   TVMContext ctx_;
 };
 
diff --git a/tests/python/contrib/test_edgetpu_runtime.py b/tests/python/contrib/test_edgetpu_runtime.py
new file mode 100644 (file)
index 0000000..a5d9e34
--- /dev/null
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import tvm
+import numpy as np
+from tvm import rpc
+from tvm.contrib import util, tflite_runtime
+# import tflite_runtime.interpreter as tflite
+
+
+def skipped_test_tflite_runtime():
+
+    def get_tflite_model_path(target_edgetpu):
+        # Return a path to the model
+        edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
+        # Obtain mobilenet model from the edgetpu repo path
+        if target_edgetpu:
+            model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant_edgetpu.tflite")
+        else:
+            model_path = os.path.join(edgetpu_path, "test_data/mobilenet_v1_1.0_224_quant.tflite")
+        return model_path
+
+    def init_interpreter(model_path, target_edgetpu):
+        # Initialize interpreter
+        if target_edgetpu:
+            edgetpu_path = os.getenv('EDGETPU_PATH', "/home/mendel/edgetpu")
+            libedgetpu = os.path.join(edgetpu_path, "libedgetpu/direct/aarch64/libedgetpu.so.1")
+            interpreter = tflite.Interpreter(
+                    model_path=model_path,
+                    experimental_delegates=[tflite.load_delegate(libedgetpu)])
+        else:
+            interpreter = tflite.Interpreter(model_path=model_path)
+        return interpreter
+
+    def check_remote(target_edgetpu=False):
+        tflite_model_path = get_tflite_model_path(target_edgetpu)
+
+        # inference via tflite interpreter python apis
+        interpreter = init_interpreter(tflite_model_path, target_edgetpu)
+        interpreter.allocate_tensors()
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+
+        input_shape = input_details[0]['shape']
+        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.uint8)
+        interpreter.set_tensor(input_details[0]['index'], tflite_input)
+        interpreter.invoke()
+        tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+        # inference via remote tvm tflite runtime
+        server = rpc.Server("localhost")
+        remote = rpc.connect(server.host, server.port)
+        ctx = remote.cpu(0)
+
+        with open(tflite_model_path, 'rb') as model_fin:
+            runtime = tflite_runtime.create(model_fin.read(), ctx)
+            runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
+            runtime.invoke()
+            out = runtime.get_output(0)
+            np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+    # Target CPU on coral board
+    check_remote()
+    # Target EdgeTPU on coral board
+    check_remote(target_edgetpu=True)
+
+if __name__ == "__main__":
+    # skipped_test_tflite_runtime()
+    pass
index e8bc663..9d396be 100644 (file)
@@ -36,16 +36,14 @@ def skipped_test_tflite_runtime():
         return tflite_model
 
 
-    def check_verify():
+    def check_local():
         tflite_fname = "model.tflite"
         tflite_model = create_tflite_model()
         temp = util.tempdir()
         tflite_model_path = temp.relpath(tflite_fname)
-        print(tflite_model_path)
         open(tflite_model_path, 'wb').write(tflite_model)
 
         # inference via tflite interpreter python apis
-        print('interpreter')
         interpreter = tflite.Interpreter(model_path=tflite_model_path)
         interpreter.allocate_tensors()
         input_details = interpreter.get_input_details()
@@ -57,11 +55,9 @@ def skipped_test_tflite_runtime():
         interpreter.invoke()
         tflite_output = interpreter.get_tensor(output_details[0]['index'])
         
-        print('tvm tflite runtime')
         # inference via tvm tflite runtime
         with open(tflite_model_path, 'rb') as model_fin:
             runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
-            runtime.allocate_tensors()
             runtime.set_input(0, tvm.nd.array(tflite_input))
             runtime.invoke()
             out = runtime.get_output(0)
@@ -95,14 +91,12 @@ def skipped_test_tflite_runtime():
 
         with open(tflite_model_path, 'rb') as model_fin:
             runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
-            runtime.allocate_tensors()
             runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
             runtime.invoke()
             out = runtime.get_output(0)
             np.testing.assert_equal(out.asnumpy(), tflite_output)
 
-
-    check_verify()
+    check_local()
     check_remote()
 
 if __name__ == "__main__":