[RUNTIME] Introduce MetadataModule to separate code compilation/interpretation and...
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 18 Jun 2020 22:18:29 +0000 (15:18 -0700)
committerGitHub <noreply@github.com>
Thu, 18 Jun 2020 22:18:29 +0000 (15:18 -0700)
19 files changed:
python/tvm/contrib/graph_runtime.py
python/tvm/runtime/vm.py
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/backend/contrib/codegen_c/codegen.cc
src/relay/backend/contrib/codegen_c/codegen_c.h
src/relay/backend/contrib/dnnl/codegen.cc
src/relay/backend/graph_runtime_codegen.cc
src/relay/backend/utils.h
src/relay/backend/vm/compiler.cc
src/runtime/graph/graph_runtime.cc
src/runtime/meta_data.h
src/runtime/metadata_module.cc [new file with mode: 0644]
src/target/source/codegen_source_base.h
src/target/source/source_module.cc
tests/python/frontend/onnx/test_forward.py
tests/python/relay/test_external_codegen.py
tests/python/relay/test_external_runtime.py
tests/python/unittest/test_runtime_module_export.py

index 740d1c3..9b714a8 100644 (file)
@@ -162,7 +162,12 @@ class GraphModule(object):
             keys = list(params.keys())
             keys.sort(key=lambda x: -np.prod(params[x].shape))
             for k in keys:
-                self._get_input(k).copyfrom(params[k])
+                # TODO(zhiics) Skip the weights for submodule in a better way.
+                # We should use MetadataModule for initialization and remove
+                # params from set_input
+                val = self._get_input(k)
+                if val:
+                    self._get_input(k).copyfrom(params[k])
 
     def run(self, **input_dict):
         """Run forward execution of the graph
index 2643ff1..8a85051 100644 (file)
@@ -106,7 +106,7 @@ class Executable(object):
 
             import numpy as np
             import tvm
-from tvm import te
+            from tvm import te
             from tvm import relay
             # define a simple network.
             x = relay.var('x', shape=(10, 10))
@@ -309,12 +309,17 @@ class VirtualMachine(object):
             Named arguments to the function.
         """
         if kwargs:
+            # kwargs is a super set of the required function parameters. We
+            # only find the ones that are needed.
             func_params = self._exec.get_function_params(func_name)
             new_args = [None] * len(func_params)
-            assert len(args) + len(kwargs) == len(func_params)
+            cnt = 0
             for k in kwargs:
-                idx = func_params.index(k)
-                new_args[idx] = kwargs[k]
+                if k in func_params:
+                    idx = func_params.index(k)
+                    new_args[idx] = kwargs[k]
+                    cnt += 1
+            assert len(args) + cnt == len(func_params)
             idx = 0
             for i, arg in enumerate(new_args):
                 if arg is None:
index 27c55a9..34c3487 100644 (file)
@@ -455,8 +455,11 @@ class RelayBuildModule : public runtime::ModuleNode {
     }
 
     Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
-    // Import all external runtime modules.
-    for (const auto& it : ext_mods) ret_.mod.Import(it);
+    // TODO(zhiics) We should be able to completely switch to MetadataModule no
+    // matter whether there are external modules or not.
+    if (!ext_mods.empty()) {
+      ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods);
+    }
   }
 
  private:
index 3687b75..2aae854 100644 (file)
@@ -571,7 +571,8 @@ class CompileEngineImpl : public CompileEngineNode {
   }
 
   Array<tvm::runtime::Module> LowerExternalFunctions() {
-    std::unordered_map<std::string, IRModule> ext_mods;
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
     std::vector<CCacheKey> cached_ext_funcs;
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
@@ -580,29 +581,31 @@ class CompileEngineImpl : public CompileEngineNode {
         auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
         CHECK(code_gen.defined()) << "No external codegen is set";
         std::string code_gen_name = code_gen.value();
-        if (ext_mods.find(code_gen_name) == ext_mods.end()) {
-          ext_mods[code_gen_name] = IRModule({}, {});
-        }
+        cached_ext_funcs.push_back(it.first);
+
         auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
         CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
                                      << AsText(src_func, false);
-        auto gv = GlobalVar(symbol_name.value());
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          CHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
         // No need to keep compiler attribute at this point, functions have been
         // extracted for specific codegen.
         src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
-        ext_mods[code_gen_name]->Add(gv, src_func);
-        cached_ext_funcs.push_back(it.first);
-      }
-    }
+        runtime::Module ext_mod = (*pf)(src_func);
 
-    Array<tvm::runtime::Module> ret;
-    for (const auto& it : ext_mods) {
-      std::string ext_name = "relay.ext." + it.first;
-      auto pf = tvm::runtime::Registry::Get(ext_name);
-      CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
-      runtime::Module ext_mod = (*pf)(it.second);
-      CHECK(ext_mod.defined()) << "No external runtime is generated.";
-      ret.push_back(ext_mod);
+        CHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
     }
 
     // No need to cache external functions as we collected them all to create
@@ -658,6 +661,7 @@ class CompileEngineImpl : public CompileEngineNode {
       CHECK(name_node.defined()) << "External function has not been attached a name yet.";
       cache_node->func_name = std::string(name_node.value());
       cache_node->target = tvm::target::ext_dev();
+      cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func);
       value->cached_func = CachedFunc(cache_node);
       return value;
     }
index 2968966..c7b5a8d 100644 (file)
@@ -25,6 +25,7 @@
 
 #include <fstream>
 #include <sstream>
+#include <string>
 
 #include "../../utils.h"
 #include "codegen_c.h"
@@ -76,43 +77,29 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
   }
 
   std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
-    // Note this is for demonstration purpose. ConstantNode doesn't necessarily
-    // belong to calls. We need to revisit this when tuples come into play.
-
     std::ostringstream decl_stream;
     std::ostringstream buf_stream;
 
     Output output;
-    output.name = "const_" + std::to_string(const_idx_++);
-
-    runtime::NDArray array = cn->data;
-    const auto& shape = array.Shape();
-
-    // Get the number of elements.
-    int64_t num_elems = 1;
-    for (auto i : shape) num_elems *= i;
-
+    // Get const: static_cast<float*>(gcc_0_consts[0]->data)
+    output.name = CreateDataReference(ext_func_id_, const_idx_);
     const auto* type_node = cn->checked_type().as<TensorTypeNode>();
     CHECK(type_node);
     const auto& dtype = GetDtypeString(type_node);
-    // Define a const buffer: float const_0[64] = {1.0, 2.0, ...};
-    //
-    // Technically, you may need: static float* const_0 = (float*)malloc(4 * 64)
-    // to avoid possible stack overflow.
-    buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
-    if (dtype == "float") {
-      float* p_flt = static_cast<float*>(array->data);
-      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
-      if (num_elems) buf_stream << p_flt[num_elems - 1];
-    } else if (dtype == "int") {
-      int* p_flt = static_cast<int*>(array->data);
-      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
-      if (num_elems) buf_stream << p_flt[num_elems - 1];
-    } else {
-      LOG(FATAL) << "Only float and int are supported for now.";
+
+    // Generate the global variable for needed ndarrays
+    if (const_array_name_.empty()) {
+      const_array_name_ = CreateNDArrayPool(ext_func_id_);
+      std::string checker = CreateInitChecker(ext_func_id_);
+      ext_func_body_.insert(ext_func_body_.begin(), checker);
     }
-    buf_stream << "};";
-    ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+
+    CHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now.";
+    output.dtype = dtype;
+
+    std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_);
+    const_vars_.push_back(const_var_name);
+    const_idx_++;
 
     return {output};
   }
@@ -175,7 +162,7 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
     buf_decl_.push_back(buf_stream.str());
 
     decl_stream << ", " << out << ");";
-    ext_func_body.push_back(decl_stream.str());
+    ext_func_body_.push_back(decl_stream.str());
 
     // Update output buffer
     // Note C codegen only handles TensorType. Therefore, we don't flatten
@@ -198,7 +185,7 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
     for (auto decl : func_decl_) {
       code_stream_ << decl << "\n";
     }
-    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
+    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out);
   }
 
  private:
@@ -213,16 +200,22 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
   /*! \brief The arguments of a C compiler compatible function. */
   Array<Var> ext_func_args_;
   /*! \brief The statements of a C compiler compatible function. */
-  std::vector<std::string> ext_func_body;
+  std::vector<std::string> ext_func_body_;
+  /*! \brief The array declared to store the constant values. */
+  std::string const_array_name_;
   /*! \brief The declaration statements of a C compiler compatible function. */
   std::vector<std::string> func_decl_;
   /*! \brief The declaration statements of buffers. */
   std::vector<std::string> buf_decl_;
+  /*! \brief The variable name to constant mapping. */
+  Array<String> const_vars_;
+
+  friend class CSourceCodegen;
 };
 
 class CSourceCodegen : public CSourceModuleCodegenBase {
  public:
-  void GenCFunc(const Function& func) {
+  std::pair<std::string, Array<String>> GenCFunc(const Function& func) {
     CHECK(func.defined()) << "Input error: expect a Relay function.";
 
     // Record the external symbol for runtime lookup.
@@ -231,14 +224,19 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
     CodegenC builder(sid);
     auto out = builder.VisitExpr(func->body);
     code_stream_ << builder.JIT(out);
+
+    return {sid, builder.const_vars_};
   }
 
   runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
     // Create headers
     code_stream_ << "#include <cstring>\n";
+    code_stream_ << "#include <vector>\n";
     code_stream_ << "#include <tvm/runtime/c_runtime_api.h>\n";
+    code_stream_ << "#include <tvm/runtime/container.h>\n";
     code_stream_ << "#include <tvm/runtime/packed_func.h>\n";
     code_stream_ << "#include <dlpack/dlpack.h>\n";
+    code_stream_ << "using namespace tvm::runtime;\n";
 
     // Append some common macro for operator definition.
     const char* operator_macro = R"op_macro(
@@ -262,22 +260,17 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
 
     code_stream_ << operator_macro << "\n\n";
 
-    if (ref->IsInstance<FunctionNode>()) {
-      GenCFunc(Downcast<Function>(ref));
-    } else if (ref->IsInstance<IRModuleNode>()) {
-      IRModule mod = Downcast<IRModule>(ref);
-      for (const auto& it : mod->functions) {
-        GenCFunc(Downcast<Function>(it.second));
-      }
-    } else {
-      LOG(FATAL) << "The input ref is expected to be a Relay function or module"
-                 << "\n";
-    }
+    CHECK(ref->IsInstance<FunctionNode>());
+    auto res = GenCFunc(Downcast<Function>(ref));
+    std::string code = code_stream_.str();
+
+    String sym = std::get<0>(res);
+    Array<String> variables = std::get<1>(res);
 
-    // Create a CSourceModule
+    // Create a CSource module
     const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
     CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
-    return (*pf)(code_stream_.str(), "cc");
+    return (*pf)(code, "c", sym, variables);
   }
 
  private:
index 3a3c486..32ab150 100644 (file)
@@ -110,8 +110,10 @@ class CodegenCBase {
    *
    * \code
    *
+   * Array<NDArray> foo_consts;
+   *
    * // An example code for the generated C function.
-   * extern "C" void foo_wrapper_(DLTensor* arg0,
+   * extern "C" int foo_wrapper_(DLTensor* arg0,
    *                              DLTensor* arg1,
    *                              DLTensor* out) {
    *   foo_(static_cast<float*>(arg0->data),
@@ -122,10 +124,17 @@ class CodegenCBase {
    *
    * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_);
    *
+   * int foo_init_wrapper_(Array<NDArray> arr) {
+   *   foo_consts = arr;
+   *   return 0;
+   * }
+   *
+   * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_);
+   *
    * \endcode
    */
   void GenerateBackendCFunc(const std::string& func_name, const Array<Var>& args,
-                            const std::vector<Output>& outs) {
+                            const std::string& const_arr_name, const std::vector<Output>& outs) {
     // Print signature
     code_stream_ << "\n";
     code_stream_ << "extern \"C\" int " << func_name << "_wrapper_(";
@@ -163,6 +172,18 @@ class CodegenCBase {
     // Generate the macro
     code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name
                  << "_wrapper_);\n\n";
+
+    if (!const_arr_name.empty()) {
+      code_stream_ << "int " << func_name << "_init_wrapper_(Array<NDArray> arr) {\n";
+      EnterScope();
+      PrintIndents();
+      code_stream_ << func_name << "_consts = arr;\n";
+      code_stream_ << "return 0;\n";
+      ExitScope();
+      code_stream_ << "}\n\n";
+      code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name
+                   << "_init_wrapper_);\n\n";
+    }
   }
 
   /*!
@@ -190,7 +211,12 @@ class CodegenCBase {
    */
   std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
                       const std::vector<std::string>& buf_decl,
-                      const std::vector<std::string>& body, const std::vector<Output>& outs) {
+                      const std::vector<std::string>& body, const std::string& const_arr_name,
+                      const std::vector<Output>& outs) {
+    // Create a declaration for global ndarrays that contain constant data.
+    if (!const_arr_name.empty()) {
+      code_stream_ << const_arr_name << "\n\n";
+    }
     // Create the signature. For example, it could be:
     // extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {}
     code_stream_ << "extern \"C\" void " << ext_func_id << "_(";
@@ -236,7 +262,7 @@ class CodegenCBase {
     code_stream_ << "}\n";
 
     // Create the wrapper to call the ext_func
-    this->GenerateBackendCFunc(ext_func_id, args, outs);
+    this->GenerateBackendCFunc(ext_func_id, args, const_arr_name, outs);
     return code_stream_.str();
   }
 
@@ -275,6 +301,55 @@ class CodegenCBase {
     return dtype;
   }
 
+  /*!
+   * \brief Creates a checker to check if the NDArray pool is initialized
+   *
+   * \param symobl The Symbol of the current function
+   *
+   * \return The created checker
+   */
+  std::string CreateInitChecker(const std::string& symbol) const {
+    std::ostringstream oss;
+    oss << "CHECK(!" << symbol
+        << "_consts.empty()) << \"C source module hasn't been initialized.\";\n";
+    return oss.str();
+  }
+
+  /*!
+   * \brief Generates the global ndarray pool declaration
+   *
+   * \param symobl The Symbol of the current function
+   *
+   * \return The created declaration
+   */
+  std::string CreateNDArrayPool(const std::string& symbol) const {
+    return "Array<NDArray> " + symbol + "_consts;";
+  }
+
+  /*!
+   * \brief Generates the reference to the data of a constant ndarray
+   *
+   * \param symobl The Symbol of the current function
+   * \param symobl const_id The index of the constant
+   *
+   * \return The created reference
+   */
+  std::string CreateDataReference(const std::string& symbol, int const_id) const {
+    return "static_cast<float*>(" + symbol + "_consts[" + std::to_string(const_id) + "]->data)";
+  }
+
+  /*!
+   * \brief Returns the variable name for a constant variable
+   *
+   * \param symobl The Symbol of the current function
+   * \param symobl const_id The index of the constant
+   *
+   * \return The created variable name
+   */
+  std::string CreateConstVar(const std::string& symbol, int const_id) const {
+    return symbol + "_const_" + std::to_string(const_id++);
+  }
+
   /*! \brief The external function source code stream. */
   std::ostringstream code_stream_;
 
index 3f9ad7c..60138ae 100644 (file)
@@ -165,33 +165,27 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
 
   std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
     Output output;
-    output.name = "const_" + std::to_string(const_idx_++);
+    // Get const: static_cast<float*>(dnnl_0_consts[0]->data)
+    output.name = CreateDataReference(ext_func_id_, const_idx_);
     output.dtype = "float";
 
-    runtime::NDArray array = cn->data;
+    // Generate the global variable for needed ndarrays
+    if (const_array_name_.empty()) {
+      const_array_name_ = CreateNDArrayPool(ext_func_id_);
+      std::string checker = CreateInitChecker(ext_func_id_);
+      ext_func_body_.insert(ext_func_body_.begin(), checker);
+    }
 
-    // Get the number of elements.
-    int64_t num_elems = 1;
-    for (auto i : array.Shape()) num_elems *= i;
+    // Give the ndarray a unique name to ease the initialization of it at
+    // runtime.
+    std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_);
+    const_vars_.push_back(const_var_name);
+    const_idx_++;
 
     const auto* type_node = cn->checked_type().as<TensorTypeNode>();
     CHECK(type_node);
     CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
 
-    std::ostringstream buf_stream;
-    const float* ptr = static_cast<float*>(array->data);
-
-    // Allocate large arrays on the static section to avoid stakc overflow.
-    // Note that this would probably increase compilation time as the source
-    // file could be really large.
-    buf_stream << "static float " << output.name << "[" << num_elems << "] = {";
-    for (int64_t i = 0; i < num_elems - 1; i++) {
-      buf_stream << ptr[i] << ",";
-    }
-    if (num_elems > 0) buf_stream << ptr[num_elems - 1];
-    buf_stream << "};\n";
-
-    ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
     return {output};
   }
 
@@ -204,12 +198,12 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
     }
 
     buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end());
-    ext_func_body.push_back(ret.decl);
+    ext_func_body_.push_back(ret.decl);
     return ret.outputs;
   }
 
   std::string JIT(const std::vector<Output>& out) {
-    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out);
+    return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out);
   }
 
  private:
@@ -341,10 +335,16 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
   int const_idx_{0};
   /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
   Array<Var> ext_func_args_;
-  /*! \brief statement of the function that will be compiled using DNNL kernels. */
-  std::vector<std::string> ext_func_body;
+  /*! \brief Statement of the function that will be compiled using DNNL kernels. */
+  std::vector<std::string> ext_func_body_;
+  /*! \brief The array declared to store the constant values. */
+  std::string const_array_name_;
   /*! \brief The declaration of intermeidate buffers. */
   std::vector<std::string> buf_decl_;
+  /*! \brief The variable name to constant mapping. */
+  Array<String> const_vars_;
+
+  friend class DNNLModuleCodegen;
 };
 
 /*!
@@ -355,7 +355,7 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
 class DNNLModuleCodegen : public CSourceModuleCodegenBase {
  public:
   // Create a corresponding DNNL function for the given relay Function.
-  void GenDNNLFunc(const Function& func) {
+  std::pair<std::string, Array<String>> GenDNNLFunc(const Function& func) {
     CHECK(func.defined()) << "Input error: expect a Relay function.";
 
     // Record the external symbol for runtime lookup.
@@ -364,6 +364,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
     CodegenDNNL builder(sid);
     auto out = builder.VisitExpr(func->body);
     code_stream_ << builder.JIT(out);
+
+    return {sid, builder.const_vars_};
   }
 
   /*!
@@ -382,32 +384,29 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
     code_stream_ << "#include <cstdint>\n";
     code_stream_ << "#include <cstdlib>\n";
     code_stream_ << "#include <cstring>\n";
+    code_stream_ << "#include <vector>\n";
     code_stream_ << "#include <tvm/runtime/c_runtime_api.h>\n";
+    code_stream_ << "#include <tvm/runtime/container.h>\n";
     code_stream_ << "#include <tvm/runtime/packed_func.h>\n";
     code_stream_ << "#include <dlpack/dlpack.h>\n";
     // dnnl_kernel file is saved under src/runtime/contrib/dnnl so that we don't
     // expose it to ordinary users. To make export_library use it, users need to
     // pass -I${PATH_TO_TVM}/src/runtime/contrib
     code_stream_ << "#include <dnnl/dnnl_kernel.h>\n";
+    code_stream_ << "using namespace tvm::runtime;\n";
     code_stream_ << "using namespace tvm::runtime::contrib;\n";
     code_stream_ << "\n";
 
-    if (ref->IsInstance<FunctionNode>()) {
-      GenDNNLFunc(Downcast<Function>(ref));
-    } else if (ref->IsInstance<IRModuleNode>()) {
-      IRModule mod = Downcast<IRModule>(ref);
-      for (const auto& it : mod->functions) {
-        GenDNNLFunc(Downcast<Function>(it.second));
-      }
-    } else {
-      LOG(FATAL) << "The input ref is expected to be a Relay function or module"
-                 << "\n";
-    }
+    CHECK(ref->IsInstance<FunctionNode>());
+    auto res = GenDNNLFunc(Downcast<Function>(ref));
+    std::string code = code_stream_.str();
+    String sym = std::get<0>(res);
+    Array<String> variables = std::get<1>(res);
 
-    // Create a CSourceModule
+    // Create a CSource module
     const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
     CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
-    return (*pf)(code_stream_.str(), "cc");
+    return (*pf)(code, "c", sym, variables);
   }
 
  private:
index 4226cc8..bc8b390 100644 (file)
@@ -368,6 +368,14 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
       CCacheKey key = (*pf0)(func, target);
       CachedFunc ext_func = (*pf1)(compile_engine_, key);
       CHECK(ext_func.defined()) << "External function is not defined.";
+
+      // Step into the functions that are handled by external codegen to
+      // collect metadata.
+      const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      std::string symobl = std::string(name_node.value());
+      ConstantUpdater const_visit(symobl, &params_);
+      const_visit(func);
+
       return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
     }
 
index 4475d43..cac6f55 100644 (file)
@@ -44,6 +44,26 @@ namespace relay {
 namespace backend {
 
 /*!
+ * \brief A helper to expand the params by adding the ones used in a given expression.
+ */
+struct ConstantUpdater : public ExprVisitor {
+ public:
+  ConstantUpdater(const std::string& symbol,
+                  std::unordered_map<std::string, runtime::NDArray>* params)
+      : symbol_(symbol), params_(params) {}
+
+  void VisitExpr_(const ConstantNode* cn) final {
+    std::string name = symbol_ + "_const_" + std::to_string(const_idx_++);
+    (*params_)[name] = cn->data;
+  }
+
+ private:
+  int const_idx_{0};
+  std::string symbol_;
+  std::unordered_map<std::string, runtime::NDArray>* params_;
+};
+
+/*!
  * \brief A simple wrapper around ExprFunctor for a single argument case.
  *  The result of visit is memoized.
  */
index 81db341..0af1949 100644 (file)
@@ -41,6 +41,7 @@
 #include <tuple>
 #include <vector>
 
+#include "../../../target/source/codegen_source_base.h"
 #include "../../backend/compile_engine.h"
 #include "../../op/op_common.h"
 #include "../../transforms/pass_util.h"
@@ -996,6 +997,10 @@ void VMCompiler::Codegen() {
     mod.CopyOnWrite();
 
     if (target_str == "ext_dev") {
+      // Collect metadata in functions that are handled by external codegen.
+      CHECK(mod->ContainGlobalVar(cfunc->func_name));
+      backend::ConstantUpdater const_visit(cfunc->func_name, &params_);
+      const_visit(Downcast<Function>(mod->Lookup(cfunc->func_name)));
       continue;
     } else if (funcs.count(target_str) == 0) {
       funcs.emplace(target_str, mod);
@@ -1006,29 +1011,20 @@ void VMCompiler::Codegen() {
 
   auto compile_engine = CompileEngine::Global();
   auto ext_mods = compile_engine->LowerExternalFunctions();
-  runtime::Module mod;
   if (funcs.size() > 0) {
     Map<String, IRModule> build_funcs;
     for (const auto& i : funcs) {
       build_funcs.Set(i.first, i.second);
     }
-    mod = tvm::build(build_funcs, target_host_);
-    CHECK(mod.operator->());
+    exec_->lib = tvm::build(build_funcs, target_host_);
   } else {
-    CHECK_EQ(ext_mods.size(), 1U)
-        << "Expect to have a TVM DSOModule when multiple runtime modules exist";
+    // There is no function handled by TVM. We create a virtual master module
+    // to make sure a DSO module will be also available.
+    exec_->lib = codegen::CSourceModuleCreate(";", "");
   }
   if (!ext_mods.empty()) {
-    if (funcs.size() == 0) {
-      mod = ext_mods[0];
-    } else {
-      // Import all external runtime modules.
-      for (auto it : ext_mods) {
-        mod.Import(it);
-      }
-    }
+    exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods);
   }
-  exec_->lib = mod;
 }
 
 runtime::Module CreateVMCompiler() {
index 59bfb68..146c097 100644 (file)
@@ -198,7 +198,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
   CHECK(size == names.size()) << "Invalid parameters file format";
   for (size_t i = 0; i < size; ++i) {
     int in_idx = GetInputIndex(names[i]);
-    CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
+    if (in_idx < 0) continue;
     uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
     CHECK_LT(eid, data_entry_.size());
 
@@ -222,7 +222,7 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
   CHECK(size == names.size()) << "Invalid parameters file format";
   for (size_t i = 0; i < size; ++i) {
     int in_idx = GetInputIndex(names[i]);
-    CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
+    if (in_idx < 0) continue;
     uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
     CHECK_LT(eid, data_entry_.size());
     CHECK_EQ(data_entry_[eid].use_count(), 1);
@@ -422,8 +422,9 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
       } else {
         in_idx = args[0];
       }
-      CHECK_GE(in_idx, 0);
-      *rv = this->GetInput(in_idx);
+      if (in_idx >= 0) {
+        *rv = this->GetInput(in_idx);
+      }
     });
   } else if (name == "get_num_outputs") {
     return PackedFunc(
index 451c0e8..03dba39 100644 (file)
 
 #include <dmlc/io.h>
 #include <dmlc/json.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "runtime_base.h"
 namespace tvm {
 namespace runtime {
 
+/*!
+ * \brief Create a metadata module object.
+ *
+ * \param metadata The variable name to ndarray mapping.
+ * \param sym_vars The symbol to the list of required constant variables
+ * mapping.
+ *
+ * \return The created metadata module.
+ */
+Module MetadataModuleCreate(
+    const std::unordered_map<std::string, NDArray>& metadata,
+    const std::unordered_map<std::string, std::vector<std::string>>& sym_vars);
+
 /*! \brief function information needed by device */
 struct FunctionInfo {
   std::string name;
diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc
new file mode 100644 (file)
index 0000000..cf3d547
--- /dev/null
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/metadata_module.cc
+ * \brief A wrapper for initializing imported modules using metadata. This
+ * module is intended to be used by various runtime in the TVM stack, i.e.
+ * graph runtime, relay VM, AOT runtime, and various user defined runtimes. It
+ * paves the way to separate the code and metedata, which makes compilation
+ * and/or interpretation more convenient. In addition, the clear separation of
+ * code and metadata significantly reduces the efforts for handling external
+ * codegen and runtimes.
+ */
+#include <tvm/node/container.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstdint>
+#include <sstream>
+
+#include "meta_data.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief The metadata module is designed to manage initialization of the
+ * imported submodules.
+ */
+class MetadataModuleNode : public ModuleNode {
+ public:
+  MetadataModuleNode(const std::unordered_map<std::string, NDArray>& metadata,
+                     const std::unordered_map<std::string, std::vector<std::string>>& sym_vars)
+      : metadata_(metadata), sym_vars_(sym_vars) {}
+
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
+    // Initialize and memoize the module.
+    // Usually, we have some warmup runs. The module initialization should be
+    // done at this stage. Therefore, runtime overhead is not a concern.
+    if (initialized_.count(name) == 0) {
+      this->InitSubModule(name);
+      initialized_.emplace(name);
+    }
+
+    // Run the module.
+    // Normally we would only have a limited number of submodules. The runtime
+    // symobl lookup overhead should be minimal.
+    CHECK(!this->imports().empty());
+    for (Module it : this->imports()) {
+      PackedFunc pf = it.GetFunction(name);
+      if (pf != nullptr) return pf;
+    }
+    return PackedFunc(nullptr);
+  }
+
+  const char* type_key() const { return "metadata"; }
+
+  /*!
+   * \brief Get the list of metadata that is required by the given module.
+   * \param symbol The symbol that is being queried.
+   * \return The list of needed NDArray.
+   */
+  Array<NDArray> GetRequiredMetadata(const std::string& symbol) {
+    Array<NDArray> ret;
+    CHECK_GT(sym_vars_.count(symbol), 0U) << "No symbol is recorded for " << symbol;
+    std::vector<std::string> vars = sym_vars_[symbol];
+    for (const auto& it : vars) {
+      CHECK_GT(metadata_.count(it), 0U) << "Found not recorded constant variable: " << it;
+      ret.push_back(metadata_[it]);
+    }
+    return ret;
+  }
+
+  /*!
+   * \brief Initialize each imported module.
+   * \param symobl The symbol used for initializing a module. It is also used
+   * for runtime lookup.
+   *
+   * \note  A module could be like the following:
+   *  MetadataModuleNode (contains all the metadata)
+   *    - CSourceModule
+   *    - JSON runtime module
+   *
+   *  The initializer iterates through the imported modules and intilizes the
+   *  found module accordingly by passing the needed metadata into it.
+   */
+  void InitSubModule(const std::string& symbol) {
+    PackedFunc init(nullptr);
+    for (Module it : this->imports()) {
+      // Get the initialization function from the imported modules.
+      std::string init_name = "__init_" + symbol;
+      init = it.GetFunction(init_name, false);
+      if (init != nullptr) {
+        auto md = GetRequiredMetadata(symbol);
+        // Initialize the module with metadata.
+        int ret = init(md);
+        // Report the error if initialization is failed.
+        CHECK_EQ(ret, 0) << TVMGetLastError();
+        break;
+      }
+    }
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final {
+    std::vector<std::string> variables;
+    std::vector<NDArray> metadata;
+    for (const auto& it : metadata_) {
+      String var_name = it.first;
+      variables.push_back(var_name);
+      metadata.push_back(it.second);
+    }
+
+    // Save all variables in the function.
+    stream->Write(variables);
+    // Save all constant data.
+    uint64_t sz = static_cast<uint64_t>(metadata.size());
+    stream->Write(sz);
+    for (uint64_t i = 0; i < sz; i++) {
+      metadata[i].Save(stream);
+    }
+
+    // Save the symbol to list of required constant variables mapping
+    std::vector<std::string> symbols;
+    std::vector<std::vector<std::string>> const_vars;
+    for (const auto& it : sym_vars_) {
+      symbols.push_back(it.first);
+      const_vars.push_back(it.second);
+    }
+
+    stream->Write(symbols);
+    sz = static_cast<uint64_t>(sym_vars_.size());
+    stream->Write(sz);
+    for (uint64_t i = 0; i < sz; i++) {
+      stream->Write(const_vars[i]);
+    }
+  }
+
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+
+    // Load the variables.
+    std::vector<std::string> variables;
+    CHECK(stream->Read(&variables)) << "Loading variables failed";
+    uint64_t sz;
+    CHECK(stream->Read(&sz, sizeof(sz))) << "Loading metadata size failed";
+    CHECK_EQ(static_cast<size_t>(sz), variables.size())
+        << "The number of variables and ndarray counts must match";
+    // Load the list of ndarray.
+    std::vector<NDArray> arrays;
+    for (uint64_t i = 0; i < sz; i++) {
+      NDArray temp;
+      temp.Load(stream);
+      arrays.push_back(temp);
+    }
+
+    std::unordered_map<std::string, NDArray> metadata;
+    for (uint64_t i = 0; i < sz; i++) {
+      CHECK_EQ(metadata.count(variables[i]), 0U);
+      metadata[variables[i]] = arrays[i];
+    }
+
+    // Load the symbol to list of required constant variables mapping
+    std::vector<std::string> symbols;
+    CHECK(stream->Read(&symbols)) << "Loading symbols failed";
+    CHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of symbols failed";
+    CHECK_EQ(static_cast<size_t>(sz), symbols.size());
+    std::vector<std::vector<std::string>> const_vars;
+    for (uint64_t i = 0; i < sz; i++) {
+      std::vector<std::string> vars;
+      CHECK(stream->Read(&vars)) << "Loading const variables failed";
+      const_vars.push_back(vars);
+    }
+
+    std::unordered_map<std::string, std::vector<std::string>> sym_vars;
+    for (uint64_t i = 0; i < sz; i++) {
+      sym_vars[symbols[i]] = const_vars[i];
+    }
+
+    auto n = make_object<MetadataModuleNode>(metadata, sym_vars);
+    return Module(n);
+  }
+
+ private:
+  /*!
+   * \brief Record if a module is initialized. It is needed by imported
+   * modules using execution engine.
+   */
+  std::unordered_set<std::string> initialized_;
+  /*! \brief Variable name to NDArray mapping. */
+  std::unordered_map<std::string, NDArray> metadata_;
+  /*! \brief Symbol name to required constant variables mapping. */
+  std::unordered_map<std::string, std::vector<std::string>> sym_vars_;
+};
+
+Module MetadataModuleCreate(
+    const std::unordered_map<std::string, NDArray>& metadata,
+    const std::unordered_map<std::string, std::vector<std::string>>& sym_vars) {
+  auto n = make_object<MetadataModuleNode>(metadata, sym_vars);
+  return Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata")
+    .set_body_typed(MetadataModuleNode::LoadFromBinary);
+}  // namespace runtime
+}  // namespace tvm
index 3901659..7e5e403 100644 (file)
@@ -135,9 +135,26 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt);
 /*!
  * \brief Create a C source module for viewing and compiling GCC code.
  * \param code The code to be viewed.
- * \param fmt The code. format.
+ * \param fmt The code format.
+ * \param symbol The symbol that the c source module represents.
+ * \param const_vars. The constant variables that the c source module needs.
+ * \return The created module.
+ */
+runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
+                                    const String& symbol = "",
+                                    const Array<String>& const_vars = {});
+
+/*!
+ * \brief Wrap the submodules in a metadata module.
+ * \param params The variable to constant mapping that is collected by the host
+ *        module.
+ * \param dso_module The host module to be wrapped.
+ * \param modules The modules to be wrapped.
+ * \return The wrapped module.
  */
-runtime::Module CSourceModuleCreate(std::string code, std::string fmt);
+runtime::Module CreateMetadataModule(
+    const std::unordered_map<std::string, runtime::NDArray>& params,
+    const runtime::Module& dso_module, const Array<runtime::Module>& modules);
 
 /*!
  * \brief Create a source module for viewing and limited saving for device.
index ba7f075..1e201e5 100644 (file)
@@ -21,6 +21,7 @@
  * \file source_module.cc
  * \brief Source code module, only for viewing
  */
+#include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
@@ -40,6 +41,44 @@ using runtime::GetFileFormat;
 using runtime::GetMetaFilePath;
 using runtime::SaveBinaryToFile;
 
+/*!
+ * \brief Create a metadata module wrapper. The helper is used by different
+ *        codegens, such as graph runtime codegen and the vm compiler.
+ *
+ * \param params The metadata for initialization of all modules.
+ * \param dso_module The DSO module that contains TVM primitives.
+ * \param modules The submodules that will be wrapped, e.g. CSource modules that
+ *        contain vendor library calls or customized runtime modules.
+ *
+ * \return The created metadata module that manages initialization of metadata.
+ */
+runtime::Module CreateMetadataModule(
+    const std::unordered_map<std::string, runtime::NDArray>& params,
+    const runtime::Module& dso_module, const Array<runtime::Module>& modules) {
+  // Wrap all submodules in the initialization wrapper.
+  std::unordered_map<std::string, std::vector<std::string>> sym_metadata;
+  for (runtime::Module it : modules) {
+    CHECK_EQ(it->type_key(), "c") << "Only csource submodule is handled for now";
+    String symbol = it.GetFunction("get_symbol")();
+    Array<String> variables = it.GetFunction("get_const_vars")();
+    std::vector<std::string> arrays;
+    for (size_t i = 0; i < variables.size(); i++) {
+      arrays.push_back(variables[i].operator std::string());
+    }
+    CHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol;
+    sym_metadata[symbol] = arrays;
+  }
+
+  // Wrap the modules.
+  runtime::Module init_m = runtime::MetadataModuleCreate(params, sym_metadata);
+  init_m.Import(dso_module);
+  for (const auto& it : modules) {
+    init_m.Import(it);
+  }
+
+  return init_m;
+}
+
 // Simulator function
 class SourceModuleNode : public runtime::ModuleNode {
  public:
@@ -67,13 +106,22 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
 // Simulator function
 class CSourceModuleNode : public runtime::ModuleNode {
  public:
-  CSourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {}
+  CSourceModuleNode(const std::string& code, const std::string& fmt, const std::string& symbol,
+                    const Array<String>& const_vars)
+      : code_(code), fmt_(fmt), symbol_(symbol), const_vars_(const_vars) {}
   const char* type_key() const { return "c"; }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
-    LOG(FATAL) << "C Source module cannot execute, to get executable module"
-               << " build TVM with \'" << fmt_ << "\' runtime support";
-    return PackedFunc();
+    if (name == "get_symbol") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_; });
+    } else if (name == "get_const_vars") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; });
+    } else {
+      LOG(FATAL) << "Unknown packed function: " << name;
+      return PackedFunc(nullptr);
+    }
   }
 
   std::string GetSource(const std::string& format) final { return code_; }
@@ -92,10 +140,14 @@ class CSourceModuleNode : public runtime::ModuleNode {
  protected:
   std::string code_;
   std::string fmt_;
+  std::string symbol_;
+  Array<String> const_vars_;
 };
 
-runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
-  auto n = make_object<CSourceModuleNode>(code, fmt);
+runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const String& symbol,
+                                    const Array<String>& const_vars) {
+  auto n = make_object<CSourceModuleNode>(code.operator std::string(), fmt.operator std::string(),
+                                          symbol.operator std::string(), const_vars);
   return runtime::Module(n);
 }
 
@@ -154,6 +206,10 @@ runtime::Module DeviceSourceModuleCreate(
 
 TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate);
 
-TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed(CSourceModuleCreate);
+TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate")
+    .set_body_typed([](String code, String fmt, String symbol, Array<String> const_vars) {
+      return CSourceModuleCreate(code, fmt, symbol, const_vars);
+    });
+
 }  // namespace codegen
 }  // namespace tvm
index 665cb7b..a82f1a5 100644 (file)
@@ -78,11 +78,10 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
             # Its possible for some onnx inputs to not be needed in the tvm
             # module, confirm its present before setting.
             try:
-                m.get_input(input_names[i])
+                m.set_input(input_names[i], tvm.nd.array(
+                    input_data[i].astype(input_data[i].dtype)))
             except:
                 continue
-            m.set_input(input_names[i], tvm.nd.array(
-                input_data[i].astype(input_data[i].dtype)))
     else:
         m.set_input(input_names, tvm.nd.array(
             input_data.astype(input_data.dtype)))
index c449ce3..6771bd1 100644 (file)
@@ -259,9 +259,52 @@ def test_extern_dnnl():
                  (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
 
 
+def test_extern_dnnl_const():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    dtype = 'float32'
+    ishape = (1, 32, 14, 14)
+    w1shape = (32, 1, 3, 3)
+    data0 = relay.var('data0', shape=(ishape), dtype=dtype)
+    w_data = np.random.uniform(0, 1, w1shape).astype(dtype)
+
+    data1 = relay.var('data0', shape=(ishape), dtype=dtype)
+    weight1 = relay.const(w_data, dtype=dtype)
+    weight2 = relay.const(w_data, dtype=dtype)
+    depthwise_conv2d_1 = relay.nn.conv2d(data1,
+                                         weight1,
+                                         kernel_size=(3, 3),
+                                         padding=(1, 1),
+                                         groups=32)
+    depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+                                         weight2,
+                                         kernel_size=(3, 3),
+                                         padding=(1, 1),
+                                         groups=32)
+    out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+    f = relay.Function([data1], out)
+    ref_mod = tvm.IRModule()
+    ref_mod['main'] = f
+
+    f = set_external_func_attr(f, "dnnl", "dnnl_0")
+    call = relay.Call(f, [data0])
+    mod = tvm.IRModule.from_expr(call)
+
+    i_data = np.random.uniform(0, 1, ishape).astype(dtype)
+
+    ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
+    ref_res = ref_ex.evaluate()(i_data)
+    check_result(mod, {"data0": i_data},
+                 (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+
+
 if __name__ == "__main__":
     test_multi_node_subgraph()
     test_extern_gcc_single_op()
     test_extern_gcc_single_op_int()
     test_extern_gcc()
     test_extern_dnnl()
+    test_extern_dnnl_const()
index 3920923..7928e4d 100644 (file)
@@ -109,7 +109,8 @@ def generate_csource_module():
     TVM_DLL_EXPORT_TYPED_FUNC(json_rt_0, ccompiler_wrapper_0_);
 
     '''
-    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc")
+    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "",
+                                                              None)
     return csource_module
 
 
@@ -175,7 +176,8 @@ def generate_engine_module():
     '''
 
     gen_json_engine()
-    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc")
+    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "",
+                                                              None)
     return csource_module
 
 
index 8473a67..8ee197d 100644 (file)
@@ -54,7 +54,8 @@ def generate_engine_module():
         '''
     import tvm.runtime._ffi_api
     gen_engine_header()
-    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc")
+    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "",
+                                                              None)
     return csource_module