[TVM][RUNTIME] A minimum example to generate external library wrappers for DSOModule...
authorZhi <5145158+zhiics@users.noreply.github.com>
Fri, 22 Nov 2019 23:31:50 +0000 (15:31 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 22 Nov 2019 23:31:50 +0000 (15:31 -0800)
CMakeLists.txt
cmake/config.cmake
include/tvm/runtime/module.h
python/tvm/module.py
src/codegen/source_module.cc
src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc [new file with mode: 0644]
src/runtime/dso_module.cc
src/runtime/graph/graph_runtime.cc
tests/python/relay/test_external_runtime.py [new file with mode: 0644]

index c99fe0d..bf18ffc 100644 (file)
@@ -232,6 +232,12 @@ if(USE_VM_PROFILER)
   list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS})
 endif(USE_VM_PROFILER)
 
+if(USE_EXAMPLE_EXT_RUNTIME)
+  message(STATUS "Build with example external runtime...")
+  file(GLOB RUNTIME_EXAMPLE_EXTERNAL_SRCS src/runtime/contrib/example_ext_runtime/*.cc)
+  list(APPEND RUNTIME_SRCS ${RUNTIME_EXAMPLE_EXTERNAL_SRCS})
+endif(USE_EXAMPLE_EXT_RUNTIME)
+
 # Module rules
 include(cmake/modules/VTA.cmake)
 include(cmake/modules/CUDA.cmake)
index cf273bf..1ef956c 100644 (file)
@@ -181,3 +181,6 @@ set(USE_VTA_TSIM ON)
 
 # Whether to build VTA FPGA driver (device side only)
 set(USE_VTA_FPGA OFF)
+
+# Whether to build the example external runtime module
+set(USE_EXAMPLE_EXT_RUNTIME OFF)
index ff096ee..b63b9bb 100644 (file)
@@ -111,7 +111,7 @@ class Module : public ObjectRef {
  *
  * \endcode
  */
-class ModuleNode : public Object {
+class TVM_DLL ModuleNode : public Object {
  public:
   /*! \brief virtual destructor */
   virtual ~ModuleNode() {}
index 98a3592..2790227 100644 (file)
@@ -144,7 +144,12 @@ class Module(ModuleBase):
             else:
                 fcompile = _cc.create_shared
         if self.type_key == "c":
-            kwargs.update({'options': ["-I" + path for path in find_include_path()]})
+            options = []
+            if "options" in kwargs:
+                opts = kwargs["options"]
+                options = opts if isinstance(opts, (list, tuple)) else [opts]
+            opts = options + ["-I" + path for path in find_include_path()]
+            kwargs.update({'options': opts})
         fcompile(file_name, files, **kwargs)
 
     def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
index adbe7ea..e23ce60 100644 (file)
@@ -185,5 +185,8 @@ runtime::Module DeviceSourceModuleCreate(
 
 TVM_REGISTER_GLOBAL("module.source_module_create")
 .set_body_typed(SourceModuleCreate);
+
+TVM_REGISTER_GLOBAL("module.csource_module_create")
+.set_body_typed(CSourceModuleCreate);
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc b/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc
new file mode 100644 (file)
index 0000000..ef6fc87
--- /dev/null
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file external_runtime_test.cc
+ * \brief Test an example runtime module to interpreting a json string.
+ *
+ * This is an exmaple runtime employed to show how we can interprete and execute
+ * a json string that represents a simple computational (sub)graph. Users will
+ * mainly need to implement four functions as follows:
+ *  - GetFunction. It is used to get the packed function from the json runtime
+ * module using a provided function name. This function returns a PackedFunc
+ * that can be directly invoked by feeding it with parameters.
+ *  - SaveToBinary. This function is used to achieve the serialization purpose.
+ * The emitted binary stream can be directly saved to disk so that users can
+ * load then back when needed.
+ *  - LoadFromBinary. This function uses binary stream to load the json that
+ * saved by SaveToBinary which essentially performs deserialization.
+ */
+#include <dmlc/logging.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <cmath>
+#include <map>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+// A simple JSON node that contains multiple inputs and a single output.
+struct NodeEntry {
+  int id;
+  int output;
+  std::vector<int> inputs;
+};
+
+/*!
+ * \brief The following 6 functions are examples for demonstration. Users need
+ * to provide their own API when they use the external library. The ones that
+ * accecpt TVMValue are wrappers used to bridge the PackedFunc and user-defined
+ * kernels.
+ */
+void Add_(float* a, int len_a, float* b, int len_b, float* c) {
+  for (int i = 0; i < len_a * len_b; i++) {
+    c[i] = a[i] + b[i];
+  }
+}
+
+int Add(TVMValue* value, int* type_code, int nargs) {
+  CHECK_EQ(nargs, 3U) << "Expect 3 args, but get " << nargs << "\n";
+  DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+  DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+  DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
+  Add_(static_cast<float*>(arg0->data), arg0->shape[0],
+       static_cast<float*>(arg1->data), arg1->shape[0],
+       static_cast<float*>(out->data));
+  return 0;
+}
+
+void Sub_(float* a, int len_a, float* b, int len_b, float* c) {
+  for (int i = 0; i < len_a * len_b; i++) {
+    c[i] = a[i] - b[i];
+  }
+}
+
+int Sub(TVMValue* value, int* type_code, int nargs) {
+  CHECK_EQ(nargs, 3U) << "Expect 3 args, but get " << nargs << "\n";
+  DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+  DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+  DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
+  Sub_(static_cast<float*>(arg0->data), arg0->shape[0],
+       static_cast<float*>(arg1->data), arg1->shape[0],
+       static_cast<float*>(out->data));
+  return 0;
+}
+
+void Mul_(float* a, int len_a, float* b, int len_b, float* c) {
+  for (int i = 0; i < len_a * len_b; i++) {
+    c[i] = a[i] * b[i];
+  }
+}
+
+int Mul(TVMValue* value, int* type_code, int nargs) {
+  CHECK_EQ(nargs, 3U) << "Expect 3 args, but get " << nargs << "\n";
+  DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+  DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+  DLTensor* out = static_cast<DLTensor*>(value[2].v_handle);
+  Mul_(static_cast<float*>(arg0->data), arg0->shape[0],
+       static_cast<float*>(arg1->data), arg1->shape[0],
+       static_cast<float*>(out->data));
+  return 0;
+}
+
+/*!
+ * \brief The example json runtime module. Here we define a simple format for
+ * the computational graph using json for demonstration purpose. Users should
+ * customize their own format.
+ */
+class ExampleJsonModule : public ModuleNode {
+ public:
+  explicit ExampleJsonModule(std::string graph_json) {
+    this->graph_json_ = graph_json;
+    ParseJson(this->graph_json_);
+  }
+
+  /*!
+   * \brief Get a PackedFunc from the example json module.
+   *
+   * \param name the name of the function.
+   * \param sptr_to_self The ObjectPtr that points to this module node.
+   *
+   * \return The function pointer when it is found, otherwise, PackedFunc(nullptr).
+   */
+  PackedFunc GetFunction(const std::string& name,
+                         const ObjectPtr<Object>& sptr_to_self) final {
+    if (this->graph_.find(name) != this->graph_.end()) {
+      this->curr_subgraph_ = name;
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        for (auto i = 0; i < args.size(); ++i) {
+          CHECK(args[i].type_code() == kNDArrayContainer || args[i].type_code() == kArrayHandle)
+              << "Expect NDArray or DLTensor as inputs"
+              << "\n";
+          if (args[i].type_code() == kArrayHandle) {
+            DLTensor* arg = args[i];
+            this->data_entry_[i].CopyFrom(arg);
+          } else {
+            NDArray arg = args[i];
+            this->data_entry_[i].CopyFrom(arg);
+          }
+        }
+        for (const auto& it : this->graph_[this->curr_subgraph_]) {
+          this->Run(it.id, it.inputs, it.output);
+        }
+        CHECK_GT(graph_.count(this->curr_subgraph_), 0U);
+        auto out_idx = graph_[this->curr_subgraph_].back().output;
+        if (args[args.size() - 1].type_code() == kArrayHandle) {
+          DLTensor* arg = args[args.size() - 1];
+          this->data_entry_[out_idx].CopyTo(arg);
+        } else {
+          NDArray arg = args[args.size() - 1];
+          this->data_entry_[out_idx].CopyTo(arg);
+        }
+        *rv = data_entry_.back();
+      });
+    } else {
+      LOG(FATAL) << "Unkown runtime type: " << name << "\n";
+      return PackedFunc();
+    }
+  }
+
+  /*!
+   * \brief Execute a function with provided arguments. The output will be
+   * packed to the last argument according to TVM's calling convention.
+   *
+   * \param id The id of the function.
+   * \param inputs The input indices that indicate where the data should be
+   * fetched in the data entry pool.
+   * \param output The output index.
+   */
+  void Run(int id, const std::vector<int>& inputs, int output) {
+    std::vector<int> args(inputs.begin(), inputs.end());
+    args.push_back(output);
+    std::vector<TVMValue> values(args.size());
+    std::vector<int> type_codes(args.size());
+    TVMArgsSetter setter(values.data(), type_codes.data());
+
+    if (op_id_[id] == "add" || op_id_[id] == "sub" || op_id_[id] == "mul") {
+      for (size_t i = 0; i < args.size(); i++) {
+        setter(i, data_entry_[args[i]]);
+      }
+    }
+
+    if (op_id_[id] == "add") {
+      Add(values.data(), type_codes.data(), args.size());
+    } else if (op_id_[id] == "sub") {
+      Sub(values.data(), type_codes.data(), args.size());
+    } else if (op_id_[id] == "mul") {
+      Mul(values.data(), type_codes.data(), args.size());
+    } else {
+      LOG(FATAL) << "Unknown op: " << op_id_[id] << "\n";
+    }
+  }
+
+  const char* type_key() const { return "examplejson"; }
+
+  /*!
+   * \brief Save the json runtime to a binary stream, which can then be
+   * serialized to disk.
+   *
+   * \param stream. The stream to save the binary.
+   */
+  void SaveToBinary(dmlc::Stream* stream) final {
+      stream->Write(this->graph_json_);
+  }
+
+  /*!
+   * \brief Parse the example json string.
+   *
+   * \param json. The json string that represents a simple computational graph.
+   *
+   * \Note this is a very simple json that only serves for demostration purpose.
+   * Users usually have their own format and they can serialize it using the
+   * SaveToBinary method and deserialize it using LoadFromFile.
+   */
+  void ParseJson(const std::string& json) {
+    std::string line;
+    std::string curr_subgraph;
+    std::stringstream ss(json);
+
+    while (std::getline(ss, line, '\n')) {
+      std::stringstream ss2(line);
+      std::string token;
+      int id = 0;
+
+      ss2 >> token;
+      if (token.find("json_rt_") != std::string::npos) {
+        curr_subgraph = token;
+        continue;
+      }
+
+      ss2 >> id;
+      if (op_id_.size() <= static_cast<size_t>(id)) {
+        op_id_.resize(id + 1);
+        data_entry_.resize(id + 1);
+      }
+
+      int64_t total_elements = 1;
+      std::vector<int64_t> shape;
+      if (token == "input") {
+        int64_t size = 0;
+        while (ss2 >> size) {
+          total_elements *= size;
+          shape.push_back(size);
+        }
+      } else {
+        op_id_[id] = token;
+        bool shape_data = false;
+        NodeEntry entry;
+        while (ss2 >> token) {
+          if (token == "shape:") {
+            shape_data = true;
+          } else if (shape_data) {
+            total_elements *= std::stoll(token);
+            shape.push_back(std::stoll(token));
+          } else if (token != "inputs:") {
+            entry.inputs.push_back(std::stoi(token));
+          }
+        }
+        entry.id = id;
+        entry.output = id;
+        graph_[curr_subgraph].push_back(entry);
+      }
+      DLContext ctx;
+      ctx.device_type = static_cast<DLDeviceType>(1);
+      ctx.device_id = 0;
+      data_entry_[id] = NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx);
+    }
+  }
+
+  /*!
+   * \brief Create a module from a file path of a serialized graph.
+   *
+   * \param path The file path contains a computational graph representation.
+   *
+   * \return The created json module.
+   */
+  static Module Create(const std::string& path) {
+    std::ifstream filep;
+    filep.open(path, std::ios::in);
+    std::string graph_json;
+    std::string line;
+    while (std::getline(filep, line)) {
+      graph_json += line;
+      graph_json += "\n";
+    }
+    filep.close();
+    auto n = tvm::runtime::make_object<ExampleJsonModule>(graph_json);
+    return Module(n);
+  }
+
+  /*!
+   * \brief Load a json module from stream.
+   *
+   * \param strm The binary stream to load json.
+   *
+   * \return The created json module.
+   */
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+    std::string graph_json;
+    stream->Read(&graph_json);
+    auto n = tvm::runtime::make_object<ExampleJsonModule>(graph_json);
+    return Module(n);
+  }
+
+ private:
+  /* \brief The json string that represents a computational graph. */
+  std::string graph_json_;
+  /* \brief The subgraph that being processed. */
+  std::string curr_subgraph_;
+  /*! \brief A simple graph from subgraph id to node entries. */
+  std::map<std::string, std::vector<NodeEntry> > graph_;
+  /* \brief A simple pool to contain the tensor for each node in the graph. */
+  std::vector<NDArray> data_entry_;
+  /* \brief A mapping from node id to op name. */
+  std::vector<std::string> op_id_;
+};
+
+TVM_REGISTER_GLOBAL("module.loadfile_examplejson")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = ExampleJsonModule::Create(args[0]);
+});
+
+TVM_REGISTER_GLOBAL("module.loadbinary_examplejson")
+.set_body_typed(ExampleJsonModule::LoadFromBinary);
+
+}  // namespace runtime
+}  // namespace tvm
+
index abbbe12..4e18957 100644 (file)
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file dso_dll_module.cc
+ * \file dso_module.cc
  * \brief Module to load from dynamic shared library.
  */
 #include <tvm/runtime/module.h>
index 9ad10c1..06e5fef 100644 (file)
@@ -396,7 +396,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu
 
   // Get compiled function from the module that contains both host and device
   // code.
-  tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, false);
+  tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, true);
   CHECK(pf != nullptr) << "no such function in module: " << param.func_name;
 
   auto fexec = [arg_ptr, pf]() {
diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py
new file mode 100644 (file)
index 0000000..887d9dc
--- /dev/null
@@ -0,0 +1,558 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from shutil import which
+import json
+import pytest
+import sys
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm import module as _tvm_module
+from tvm.contrib import util
+
+tmp_path = util.tempdir()
+
+
+def generate_csource_module():
+    """Mock the codegen with an external library (e.g., CBLAS/cuDNN)"""
+
+    code = r'''
+    #include <tvm/runtime/c_runtime_api.h>
+    #include <dlpack/dlpack.h>
+    #include <cstdint>
+    #include <cstring>
+    #include <iostream>
+
+    #define GCC_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_)           \
+      extern "C" void p_ID_(float* a, float* b, float* out) { \
+        for (int64_t i = 0; i < p_DIM1_; ++i) {               \
+          out[i] = a[i] p_OP_ b[i];                           \
+        }                                                     \
+      }
+
+    #define GCC_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_)  \
+      extern "C" void p_ID_(float* a, float* b, float* out) { \
+        for (int64_t i = 0; i < p_DIM1_; ++i) {               \
+          for (int64_t j = 0; j < p_DIM2_; ++j) {             \
+            int64_t k = i * p_DIM2_ + j;                      \
+            out[k] = a[k] p_OP_ b[k];                         \
+          }                                                   \
+        }                                                     \
+      }
+    GCC_BINARY_OP_2D(gcc_1_0, *, 10, 10);
+    GCC_BINARY_OP_2D(gcc_1_1, -, 10, 10);
+    GCC_BINARY_OP_2D(gcc_1_2, +, 10, 10);
+
+    extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5,
+                           float* gcc_input6, float* gcc_input7, float* out) {
+      float* buf_0 = (float*)malloc(4 * 100);
+      float* buf_1 = (float*)malloc(4 * 100);
+      gcc_1_2(gcc_input4, gcc_input5, buf_0);
+      gcc_1_1(buf_0, gcc_input6, buf_1);
+      gcc_1_0(buf_1, gcc_input7, out);
+      free(buf_0);
+      free(buf_1);
+    }
+
+    extern "C" int json_rt_1(TVMValue* value, int* type_code, int nargs) {
+      if (nargs != 5) {
+        printf("Expect 5 args, but get %d", nargs);
+        return 1;
+      }
+      DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+      DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+      DLTensor* arg2 = static_cast<DLTensor*>(value[2].v_handle);
+      DLTensor* arg3 = static_cast<DLTensor*>(value[3].v_handle);
+      DLTensor* out = static_cast<DLTensor*>(value[4].v_handle);
+      gcc_1_(static_cast<float*>(arg0->data), static_cast<float*>(arg1->data),
+             static_cast<float*>(arg2->data), static_cast<float*>(arg3->data),
+             static_cast<float*>(out->data));
+      return 0;
+    }
+
+    GCC_BINARY_OP_2D(gcc_0_0, *, 10, 10);
+    GCC_BINARY_OP_2D(gcc_0_1, -, 10, 10);
+    GCC_BINARY_OP_2D(gcc_0_2, +, 10, 10);
+
+    extern "C" void gcc_0_(float* gcc_input0, float* gcc_input1,
+                           float* gcc_input2, float* gcc_input3, float* out) {
+      float* buf_0 = (float*)malloc(4 * 100);
+      float* buf_1 = (float*)malloc(4 * 100);
+      gcc_0_2(gcc_input0, gcc_input1, buf_0);
+      gcc_0_1(buf_0, gcc_input2, buf_1);
+      gcc_0_0(buf_1, gcc_input3, out);
+      free(buf_0);
+      free(buf_1);
+    }
+
+    extern "C" int json_rt_0(TVMValue* value, int* type_code, int nargs) {
+      if (nargs != 5) {
+        printf("Expect 5 args, but get %d", nargs);
+        return 1;
+      }
+      DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+      DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+      DLTensor* arg2 = static_cast<DLTensor*>(value[2].v_handle);
+      DLTensor* arg3 = static_cast<DLTensor*>(value[3].v_handle);
+      DLTensor* out = static_cast<DLTensor*>(value[4].v_handle);
+      gcc_0_(static_cast<float*>(arg0->data), static_cast<float*>(arg1->data),
+             static_cast<float*>(arg2->data), static_cast<float*>(arg3->data),
+             static_cast<float*>(out->data));
+      return 0;
+    }
+    '''
+    csource_module = _tvm_module.csource_module_create(code, "cc")
+    return csource_module
+
+
+def generate_engine_module():
+    """
+    Mock the codegen of an external backend with its own runtime engine
+    (e.g., MKL-DNN/TensorRT)
+    """
+
+    code = r'''
+    #include <tvm/runtime/c_runtime_api.h>
+    #include <dlpack/dlpack.h>
+    #include "gcc_engine.h"
+
+    extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5,
+            float* gcc_input6, float* gcc_input7, float* out) {
+            
+        std::string graph =
+            "add_2d,10,10\n"
+            "sub_2d,10,10\n"
+            "mul_2d,10,10\n";
+
+        Engine engine;
+        engine.run(graph, {gcc_input4, gcc_input5, gcc_input6, gcc_input7}, out);
+    }
+
+
+    extern "C" int json_rt_1(TVMValue* value, int* type_code, int nargs) {
+        if (nargs != 5) {
+            printf("Expect 5 args, but get %d", nargs);
+            return 1;
+        }
+        DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+        DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+        DLTensor* arg2 = static_cast<DLTensor*>(value[2].v_handle);
+        DLTensor* arg3 = static_cast<DLTensor*>(value[3].v_handle);
+        DLTensor* out = static_cast<DLTensor*>(value[4].v_handle);
+        gcc_1_(static_cast<float*>(arg0->data), static_cast<float*>(arg1->data),
+                static_cast<float*>(arg2->data), static_cast<float*>(arg3->data),
+                static_cast<float*>(out->data));
+        return 0;
+    }
+
+    extern "C" void gcc_0_(float* gcc_input0, float* gcc_input1,
+            float* gcc_input2, float* gcc_input3, float* out) {
+            
+        std::string graph =
+            "add_2d,10,10\n"
+            "sub_2d,10,10\n"
+            "mul_2d,10,10\n";
+
+        Engine engine;
+        engine.run(graph, {gcc_input0, gcc_input1, gcc_input2, gcc_input3}, out);
+
+    }
+
+    extern "C" int json_rt_0(TVMValue* value, int* type_code, int nargs) {
+        if (nargs != 5) {
+            printf("Expect 5 args, but get %d", nargs);
+            return 1;
+        }
+        DLTensor* arg0 = static_cast<DLTensor*>(value[0].v_handle);
+        DLTensor* arg1 = static_cast<DLTensor*>(value[1].v_handle);
+        DLTensor* arg2 = static_cast<DLTensor*>(value[2].v_handle);
+        DLTensor* arg3 = static_cast<DLTensor*>(value[3].v_handle);
+        DLTensor* out = static_cast<DLTensor*>(value[4].v_handle);
+        gcc_0_(static_cast<float*>(arg0->data), static_cast<float*>(arg1->data),
+                static_cast<float*>(arg2->data), static_cast<float*>(arg3->data),
+                static_cast<float*>(out->data));
+        return 0;
+    }
+    '''
+
+    gen_gcc_engine()
+    csource_module = _tvm_module.csource_module_create(code, "cc")
+    return csource_module
+
+
+def gen_gcc_engine():
+    """An example of external backend runtime engine. This is supposed to be provided
+      by third-party vendors and included when building the generated external kernel code.
+    """
+
+    code = r'''
+    #ifndef _GCC_ENGINE_H_
+    #define _GCC_ENGINE_H_
+    #include <cstdint>
+    #include <string>
+    #include <sstream>
+    #include <vector>
+
+    #define GCC_BINARY_OP_2D(p_ID_, p_OP_)  \
+      void p_ID_(int64_t dim1, int64_t dim2, float* a, float* b, float* out) { \
+        for (int64_t i = 0; i < dim1; ++i) {                                   \
+          for (int64_t j = 0; j < dim2; ++j) {                                 \
+            int64_t k = i * dim2 + j;                                          \
+            out[k] = a[k] p_OP_ b[k];                                          \
+          }                                                                    \
+        }                                                                      \
+      }
+    GCC_BINARY_OP_2D(add_2d, +);
+    GCC_BINARY_OP_2D(sub_2d, -);
+    GCC_BINARY_OP_2D(mul_2d, *);
+
+    struct Layer {
+        void (*op)(int64_t, int64_t, float*, float*, float*);
+        std::vector<int64_t> shapes;
+        std::vector<float*> args;
+    };
+
+    class Engine {
+    public:
+        float* alloc_buffer(int64_t size) {
+            float* buf = (float*)malloc(sizeof(float) * size);
+            buffers.push_back(buf);
+            return buf;
+        }
+        void add(std::string op, int64_t dim1, int64_t dim2, float* in1, float* in2, float* out) {
+            Layer layer;
+            layer.shapes.push_back(dim1);
+            layer.shapes.push_back(dim2);
+            layer.args.push_back(in1);
+            layer.args.push_back(in2);
+            layer.args.push_back(out);
+
+            if (op == "add_2d")
+                layer.op = &add_2d;
+            else if (op == "sub_2d")
+                layer.op = &sub_2d;
+            else if (op == "mul_2d")
+                layer.op = &mul_2d;
+            net.push_back(layer);
+            return ;
+        }
+
+        void run(std::string graph, std::vector<float*> args, float* out) {
+            std::stringstream ss(graph);
+            std::string line;
+            int layer_idx = 0;
+            int arg_idx = 0;
+            float* buf = nullptr;
+
+            while (std::getline(ss, line, '\n')) {
+                std::stringstream ss2(line);
+                std::string token;
+                std::vector<std::string> attrs;
+                while (std::getline(ss2, token, ',')) {
+                    attrs.push_back(token);
+                }
+                int64_t dim1 = stoll(attrs[1]);
+                int64_t dim2 = stoll(attrs[2]);
+                auto out_buf = this->alloc_buffer(dim1 * dim2);
+
+                if (layer_idx == 0) {
+                    this->add(attrs[0], dim1, dim2, args[0], args[1], out_buf);
+                    buf = out_buf;
+                    arg_idx = 2;
+                }
+                else {
+                    this->add(attrs[0], dim1, dim2, buf, args[arg_idx], out_buf);
+                    buf = out_buf;
+                    arg_idx++;
+                }
+                layer_idx++;
+            }
+            this->net.back().args.back() = out;
+
+            for (auto layer : net) {
+                (*layer.op)(layer.shapes[0], layer.shapes[1], layer.args[0], layer.args[1], layer.args[2]);
+            }
+        }
+        ~Engine() {
+            for (auto buf : buffers) {
+                free(buf);
+            }
+        }
+    private:
+        std::vector<Layer> net;
+        std::vector<float*> buffers;
+    };
+
+    #endif
+    '''
+    header_file = tmp_path.relpath("gcc_engine.h")
+    with open(header_file, 'w') as f:
+        f.write(code)
+
+
+def get_synthetic_lib():
+    x = relay.var('x', shape=(10, 10))
+    w0 = relay.var('w0', shape=(10, 10))
+    w1 = relay.var('w1', shape=(10, 10))
+    w2 = relay.var('w2', shape=(10, 10))
+    w3 = relay.var('w3', shape=(10, 10))
+    w4 = relay.var('w4', shape=(10, 10))
+    w5 = relay.var('w5', shape=(10, 10))
+    w6 = relay.var('w6', shape=(10, 10))
+    w7 = relay.var('w7', shape=(10, 10))
+
+    # subgraph0
+    gcc_input0 = relay.var('gcc_input0', shape=(10, 10))
+    gcc_input1 = relay.var('gcc_input1', shape=(10, 10))
+    gcc_input2 = relay.var('gcc_input2', shape=(10, 10))
+    gcc_input3 = relay.var('gcc_input3', shape=(10, 10))
+    subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2,
+                                gcc_input3], relay.copy(gcc_input0))
+    subgraph0 = subgraph0.set_attribute(
+        "Primitive", tvm.expr.IntImm("int32", 1))
+
+    # Call subgraph0
+    subgraph0_ret = relay.Call(subgraph0, [x, w0, w1, w2])
+
+    # subgraph1
+    gcc_input4 = relay.var('gcc_input4', shape=(10, 10))
+    gcc_input5 = relay.var('gcc_input5', shape=(10, 10))
+    gcc_input6 = relay.var('gcc_input6', shape=(10, 10))
+    gcc_input7 = relay.var('gcc_input7', shape=(10, 10))
+    subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6,
+                                gcc_input7], relay.copy(gcc_input4))
+    subgraph1 = subgraph1.set_attribute(
+        "Primitive", tvm.expr.IntImm("int32", 1))
+
+    # Call subgraph1
+    subgraph1_ret = relay.Call(subgraph1, [x, w3, w4, w5])
+
+    # Other ops that will be executed on TVM.
+    add2 = relay.add(x, w6)
+    sub2 = relay.subtract(add2, w7)
+    ret = relay.concatenate((subgraph0_ret, subgraph1_ret, sub2), 0)
+    func = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], ret)
+    mod = relay.Module.from_expr(func)
+    _, lib, _ = relay.build(mod, "llvm")
+    return lib
+
+def get_whole_graph_json():
+    nodex = {"op": "null", "name": "x", "inputs": []}
+    node0 = {"op": "null", "name": "w0", "inputs": []}
+    node1 = {"op": "null", "name": "w1", "inputs": []}
+    node2 = {"op": "null", "name": "w2", "inputs": []}
+    node3 = {"op": "null", "name": "w3", "inputs": []}
+    node4 = {"op": "null", "name": "w4", "inputs": []}
+    node5 = {"op": "null", "name": "w5", "inputs": []}
+    node6 = {"op": "null", "name": "w6", "inputs": []}
+    node7 = {"op": "null", "name": "w7", "inputs": []}
+
+    subgraph0 = {
+        "op": "tvm_op",
+        "name": "json_rt_0",
+        "attrs": {
+            "num_outputs": "1",
+            "num_inputs": "4",
+            "func_name": "json_rt_0",
+            "flatten_data": "0"
+        },
+        "inputs": [
+            [0, 0, 0],
+            [1, 0, 0],
+            [2, 0, 0],
+            [3, 0, 0],
+        ]
+    }
+    subgraph1 = {
+        "op": "tvm_op",
+        "name": "json_rt_1",
+        "attrs": {
+            "num_outputs": "1",
+            "num_inputs": "4",
+            "func_name": "json_rt_1",
+            "flatten_data": "0"
+        },
+        "inputs": [
+            [0, 0, 0],
+            [4, 0, 0],
+            [5, 0, 0],
+            [6, 0, 0],
+        ]
+    }
+
+    fused_op = {
+        "op": "tvm_op",
+        "name": "fused_add_subtract_concatenate",
+        "attrs": {
+            "num_outputs": "1",
+            "num_inputs": "5",
+            "func_name": "fused_add_subtract_concatenate",
+            "flatten_data": "0"
+        },
+        "inputs": [
+            [9, 0, 0],
+            [10, 0, 0],
+            [0, 0, 0],
+            [7, 0, 0],
+            [8, 0, 0]
+        ]
+    }
+    nodes = [nodex, node0, node1, node2, node3, node4,
+             node5, node6, node7, subgraph0, subgraph1, fused_op]
+    arg_nodes = [0, 1, 2, 3, 4, 5, 6, 7, 8]
+    heads = [[11, 0, 0]]
+    node_row_ptr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+    storage_id = ["list_int", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]
+
+    shape = ["list_shape", [
+        [10, 10], [10, 10], [10, 10], [10, 10], [10, 10], [10, 10],
+        [10, 10], [10, 10], [10, 10], [10, 10], [10, 10], [30, 10]]]
+
+    dltype = ["list_str", [
+        "float32", "float32", "float32", "float32", "float32", "float32",
+        "float32", "float32", "float32", "float32", "float32", "float32"]]
+
+    attrs = {
+        "shape": shape,
+        "dltype": dltype,
+        "storage_id": storage_id,
+    }
+
+    graph = {"nodes": nodes,
+             "arg_nodes": arg_nodes,
+             "node_row_ptr": node_row_ptr,
+             "heads": heads,
+             "attrs": attrs}
+
+    return json.dumps(graph)
+
+
+def run_extern(label, get_extern_src, **kwargs):
+    if which("gcc") is None:
+        print("Skip test because gcc is not available.")
+
+    obj_name = "{}.o".format(label)
+    lib_name = "external_{}.so".format(label)
+
+    # Get Json and the compiled library.
+    graph_json = get_whole_graph_json()
+    lib = get_synthetic_lib()
+    lib.save(obj_name)
+
+    # library that contains external code.
+    csource_module = get_extern_src()
+    kwargs["options"] = [obj_name] + kwargs["options"]
+    lib_path = tmp_path.relpath(lib_name)
+    csource_module.export_library(lib_path, fcompile=False, **kwargs)
+    # load module for execution.
+    lib = tvm.module.load(lib_path)
+    mod = tvm.contrib.graph_runtime.create(graph_json, lib, tvm.cpu(0))
+
+    x_data = np.random.rand(10, 10).astype('float32')
+    mod.set_input("x", x_data)
+    w_data = []
+    for i in range(8):
+        data = np.random.rand(10, 10).astype('float32')
+        w_data.append(data)
+        var = "w" + str(i)
+        mod.set_input(var, data)
+    mod.run()
+    out = tvm.nd.empty((30, 10), ctx=tvm.cpu())
+    out = mod.get_output(0, out)
+    tvm.testing.assert_allclose(
+        out.asnumpy(),
+        np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
+                        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+                        x_data + w_data[6] - w_data[7]),
+                       axis=0))
+
+
+def test_dso_extern():
+    run_extern("lib", generate_csource_module, options=["-O2", "-std=c++11"])
+
+
+def test_engine_extern():
+    run_extern("engine",
+               generate_engine_module,
+               options=["-O2", "-std=c++11", "-I" + tmp_path.relpath("")])
+
+def test_json_extern():
+    if which("gcc") is None:
+        print("Skip test because gcc is not available.")
+
+    # Get subgraph Json.
+    subgraph_json = ("json_rt_0\n" +
+                     "input 0 10 10\n" +
+                     "input 1 10 10\n" +
+                     "input 2 10 10\n" +
+                     "input 3 10 10\n" +
+                     "add 4 inputs: 0 1 shape: 10 10\n" +
+                     "sub 5 inputs: 4 2 shape: 10 10\n" +
+                     "mul 6 inputs: 5 3 shape: 10 10\n" +
+                     "json_rt_1\n" +
+                     "input 0 10 10\n" +
+                     "input 1 10 10\n" +
+                     "input 2 10 10\n" +
+                     "input 3 10 10\n" +
+                     "add 4 inputs: 0 1 shape: 10 10\n" +
+                     "sub 5 inputs: 4 2 shape: 10 10\n" +
+                     "mul 6 inputs: 5 3 shape: 10 10")
+
+    subgraph_path = tmp_path.relpath('subgraph.examplejson')
+    with open(subgraph_path, 'w') as f:
+        f.write(subgraph_json)
+
+    # Get Json and module.
+    graph_json = get_whole_graph_json()
+
+
+    lib = get_synthetic_lib()
+    ext_lib = tvm.module.load(subgraph_path, "examplejson")
+    lib.import_module(ext_lib)
+    lib_name = 'external.so'
+    lib_path = tmp_path.relpath(lib_name)
+    lib.export_library(lib_path)
+
+    # load module for execution.
+    lib = tvm.module.load(lib_path)
+    mod = tvm.contrib.graph_runtime.create(graph_json, lib, tvm.cpu(0))
+
+    x_data = np.random.rand(10, 10).astype('float32')
+    mod.set_input("x", x_data)
+    w_data = []
+    for i in range(8):
+        data = np.random.rand(10, 10).astype('float32')
+        w_data.append(data)
+        var = "w" + str(i)
+        mod.set_input(var, data)
+
+    mod.run()
+    out = tvm.nd.empty((30, 10), ctx=tvm.cpu())
+    out = mod.get_output(0, out)
+    tvm.testing.assert_allclose(
+        out.asnumpy(),
+        np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
+                        ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+                        x_data + w_data[6] - w_data[7]),
+                       axis=0))
+
+
+if __name__ == "__main__":
+    test_dso_extern()
+    test_engine_extern()
+    test_json_extern()