From eacfe890669d026c3d3aea4d03f4f773819242dd Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Thu, 18 Jun 2020 15:18:29 -0700 Subject: [PATCH] [RUNTIME] Introduce MetadataModule to separate code compilation/interpretation and weight initialization (#5770) --- python/tvm/contrib/graph_runtime.py | 7 +- python/tvm/runtime/vm.py | 13 +- src/relay/backend/build_module.cc | 7 +- src/relay/backend/compile_engine.cc | 38 ++-- src/relay/backend/contrib/codegen_c/codegen.cc | 85 ++++---- src/relay/backend/contrib/codegen_c/codegen_c.h | 83 +++++++- src/relay/backend/contrib/dnnl/codegen.cc | 73 ++++--- src/relay/backend/graph_runtime_codegen.cc | 8 + src/relay/backend/utils.h | 20 ++ src/relay/backend/vm/compiler.cc | 24 +-- src/runtime/graph/graph_runtime.cc | 9 +- src/runtime/meta_data.h | 16 ++ src/runtime/metadata_module.cc | 222 +++++++++++++++++++++ src/target/source/codegen_source_base.h | 21 +- src/target/source/source_module.cc | 70 ++++++- tests/python/frontend/onnx/test_forward.py | 5 +- tests/python/relay/test_external_codegen.py | 43 ++++ tests/python/relay/test_external_runtime.py | 6 +- .../python/unittest/test_runtime_module_export.py | 3 +- 19 files changed, 609 insertions(+), 144 deletions(-) create mode 100644 src/runtime/metadata_module.cc diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 740d1c3..9b714a8 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -162,7 +162,12 @@ class GraphModule(object): keys = list(params.keys()) keys.sort(key=lambda x: -np.prod(params[x].shape)) for k in keys: - self._get_input(k).copyfrom(params[k]) + # TODO(zhiics) Skip the weights for submodule in a better way. + # We should use MetadataModule for initialization and remove + # params from set_input + val = self._get_input(k) + if val: + self._get_input(k).copyfrom(params[k]) def run(self, **input_dict): """Run forward execution of the graph diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 2643ff1..8a85051 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -106,7 +106,7 @@ class Executable(object): import numpy as np import tvm -from tvm import te + from tvm import te from tvm import relay # define a simple network. x = relay.var('x', shape=(10, 10)) @@ -309,12 +309,17 @@ class VirtualMachine(object): Named arguments to the function. """ if kwargs: + # kwargs is a super set of the required function parameters. We + # only find the ones that are needed. func_params = self._exec.get_function_params(func_name) new_args = [None] * len(func_params) - assert len(args) + len(kwargs) == len(func_params) + cnt = 0 for k in kwargs: - idx = func_params.index(k) - new_args[idx] = kwargs[k] + if k in func_params: + idx = func_params.index(k) + new_args[idx] = kwargs[k] + cnt += 1 + assert len(args) + cnt == len(func_params) idx = 0 for i, arg in enumerate(new_args): if arg is None: diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 27c55a9..34c3487 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -455,8 +455,11 @@ class RelayBuildModule : public runtime::ModuleNode { } Array ext_mods = graph_codegen_->GetExternalModules(); - // Import all external runtime modules. - for (const auto& it : ext_mods) ret_.mod.Import(it); + // TODO(zhiics) We should be able to completely switch to MetadataModule no + // matter whether there are external modules or not. 
+ if (!ext_mods.empty()) { + ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods); + } } private: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 3687b75..2aae854 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -571,7 +571,8 @@ class CompileEngineImpl : public CompileEngineNode { } Array LowerExternalFunctions() { - std::unordered_map ext_mods; + Array ret; + std::unordered_map cached_symbol; std::vector cached_ext_funcs; for (const auto& it : cache_) { auto src_func = it.first->source_func; @@ -580,29 +581,31 @@ class CompileEngineImpl : public CompileEngineNode { auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; std::string code_gen_name = code_gen.value(); - if (ext_mods.find(code_gen_name) == ext_mods.end()) { - ext_mods[code_gen_name] = IRModule({}, {}); - } + cached_ext_funcs.push_back(it.first); + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(symbol_name.value()); + + std::string sn = symbol_name.value(); + if (cached_symbol.count(sn)) { + cached_symbol[sn] = code_gen_name; + } else { + CHECK_NE(sn, code_gen_name) + << "Found duplicated symbol: " << sn << " for: " << code_gen_name; + } + + std::string ext_name = "relay.ext." + code_gen_name; + auto pf = tvm::runtime::Registry::Get(ext_name); + CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; // No need to keep compiler attribute at this point, functions have been // extracted for specific codegen. src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); - ext_mods[code_gen_name]->Add(gv, src_func); - cached_ext_funcs.push_back(it.first); - } - } + runtime::Module ext_mod = (*pf)(src_func); - Array ret; - for (const auto& it : ext_mods) { - std::string ext_name = "relay.ext." + it.first; - auto pf = tvm::runtime::Registry::Get(ext_name); - CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - runtime::Module ext_mod = (*pf)(it.second); - CHECK(ext_mod.defined()) << "No external runtime is generated."; - ret.push_back(ext_mod); + CHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } } // No need to cache external functions as we collected them all to create @@ -658,6 +661,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(name_node.defined()) << "External function has not been attached a name yet."; cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); + cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func); value->cached_func = CachedFunc(cache_node); return value; } diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 2968966..c7b5a8d 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -25,6 +25,7 @@ #include #include +#include #include "../../utils.h" #include "codegen_c.h" @@ -76,43 +77,29 @@ class CodegenC : public MemoizedExprTranslator>, public Code } std::vector VisitExpr_(const ConstantNode* cn) final { - // Note this is for demonstration purpose. ConstantNode doesn't necessarily - // belong to calls. We need to revisit this when tuples come into play. 
- std::ostringstream decl_stream; std::ostringstream buf_stream; Output output; - output.name = "const_" + std::to_string(const_idx_++); - - runtime::NDArray array = cn->data; - const auto& shape = array.Shape(); - - // Get the number of elements. - int64_t num_elems = 1; - for (auto i : shape) num_elems *= i; - + // Get const: static_cast(gcc_0_consts[0]->data) + output.name = CreateDataReference(ext_func_id_, const_idx_); const auto* type_node = cn->checked_type().as(); CHECK(type_node); const auto& dtype = GetDtypeString(type_node); - // Define a const buffer: float const_0[64] = {1.0, 2.0, ...}; - // - // Technically, you may need: static float* const_0 = (float*)malloc(4 * 64) - // to avoid possible stack overflow. - buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {"; - if (dtype == "float") { - float* p_flt = static_cast(array->data); - for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", "; - if (num_elems) buf_stream << p_flt[num_elems - 1]; - } else if (dtype == "int") { - int* p_flt = static_cast(array->data); - for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", "; - if (num_elems) buf_stream << p_flt[num_elems - 1]; - } else { - LOG(FATAL) << "Only float and int are supported for now."; + + // Generate the global variable for needed ndarrays + if (const_array_name_.empty()) { + const_array_name_ = CreateNDArrayPool(ext_func_id_); + std::string checker = CreateInitChecker(ext_func_id_); + ext_func_body_.insert(ext_func_body_.begin(), checker); } - buf_stream << "};"; - ext_func_body.insert(ext_func_body.begin(), buf_stream.str()); + + CHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now."; + output.dtype = dtype; + + std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_); + const_vars_.push_back(const_var_name); + const_idx_++; return {output}; } @@ -175,7 +162,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code buf_decl_.push_back(buf_stream.str()); decl_stream << ", " << out << ");"; - ext_func_body.push_back(decl_stream.str()); + ext_func_body_.push_back(decl_stream.str()); // Update output buffer // Note C codegen only handles TensorType. Therefore, we don't flatten @@ -198,7 +185,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code for (auto decl : func_decl_) { code_stream_ << decl << "\n"; } - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); } private: @@ -213,16 +200,22 @@ class CodegenC : public MemoizedExprTranslator>, public Code /*! \brief The arguments of a C compiler compatible function. */ Array ext_func_args_; /*! \brief The statements of a C compiler compatible function. */ - std::vector ext_func_body; + std::vector ext_func_body_; + /*! \brief The array declared to store the constant values. */ + std::string const_array_name_; /*! \brief The declaration statements of a C compiler compatible function. */ std::vector func_decl_; /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; + /*! \brief The variable name to constant mapping. */ + Array const_vars_; + + friend class CSourceCodegen; }; class CSourceCodegen : public CSourceModuleCodegenBase { public: - void GenCFunc(const Function& func) { + std::pair> GenCFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. 
@@ -231,14 +224,19 @@ class CSourceCodegen : public CSourceModuleCodegenBase { CodegenC builder(sid); auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); + + return {sid, builder.const_vars_}; } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { // Create headers code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "using namespace tvm::runtime;\n"; // Append some common macro for operator definition. const char* operator_macro = R"op_macro( @@ -262,22 +260,17 @@ class CSourceCodegen : public CSourceModuleCodegenBase { code_stream_ << operator_macro << "\n\n"; - if (ref->IsInstance()) { - GenCFunc(Downcast(ref)); - } else if (ref->IsInstance()) { - IRModule mod = Downcast(ref); - for (const auto& it : mod->functions) { - GenCFunc(Downcast(it.second)); - } - } else { - LOG(FATAL) << "The input ref is expected to be a Relay function or module" - << "\n"; - } + CHECK(ref->IsInstance()); + auto res = GenCFunc(Downcast(ref)); + std::string code = code_stream_.str(); + + String sym = std::get<0>(res); + Array variables = std::get<1>(res); - // Create a CSourceModule + // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code_stream_.str(), "cc"); + return (*pf)(code, "c", sym, variables); } private: diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 3a3c486..32ab150 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -110,8 +110,10 @@ class CodegenCBase { * * \code * + * Array foo_consts; + * * // An example code for the generated C function. - * extern "C" void foo_wrapper_(DLTensor* arg0, + * extern "C" int foo_wrapper_(DLTensor* arg0, * DLTensor* arg1, * DLTensor* out) { * foo_(static_cast(arg0->data), @@ -122,10 +124,17 @@ class CodegenCBase { * * TVM_DLL_EXPORT_TYPED_FUNC(foo, foo_wrapper_); * + * int foo_init_wrapper_(Array arr) { + * foo_consts = arr; + * return 0; + * } + * + * TVM_DLL_EXPORT_TYPED_FUNC(__init_foo, foo_init_wrapper_); + * * \endcode */ void GenerateBackendCFunc(const std::string& func_name, const Array& args, - const std::vector& outs) { + const std::string& const_arr_name, const std::vector& outs) { // Print signature code_stream_ << "\n"; code_stream_ << "extern \"C\" int " << func_name << "_wrapper_("; @@ -163,6 +172,18 @@ class CodegenCBase { // Generate the macro code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name << "_wrapper_);\n\n"; + + if (!const_arr_name.empty()) { + code_stream_ << "int " << func_name << "_init_wrapper_(Array arr) {\n"; + EnterScope(); + PrintIndents(); + code_stream_ << func_name << "_consts = arr;\n"; + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}\n\n"; + code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(__init_" << func_name << ", " << func_name + << "_init_wrapper_);\n\n"; + } } /*! 
@@ -190,7 +211,12 @@ class CodegenCBase { */ std::string JitImpl(const std::string& ext_func_id, const Array& args, const std::vector& buf_decl, - const std::vector& body, const std::vector& outs) { + const std::vector& body, const std::string& const_arr_name, + const std::vector& outs) { + // Create a declaration for global ndarrays that contain constant data. + if (!const_arr_name.empty()) { + code_stream_ << const_arr_name << "\n\n"; + } // Create the signature. For example, it could be: // extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {} code_stream_ << "extern \"C\" void " << ext_func_id << "_("; @@ -236,7 +262,7 @@ class CodegenCBase { code_stream_ << "}\n"; // Create the wrapper to call the ext_func - this->GenerateBackendCFunc(ext_func_id, args, outs); + this->GenerateBackendCFunc(ext_func_id, args, const_arr_name, outs); return code_stream_.str(); } @@ -275,6 +301,55 @@ class CodegenCBase { return dtype; } + /*! + * \brief Creates a checker to check if the NDArray pool is initialized + * + * \param symobl The Symbol of the current function + * + * \return The created checker + */ + std::string CreateInitChecker(const std::string& symbol) const { + std::ostringstream oss; + oss << "CHECK(!" << symbol + << "_consts.empty()) << \"C source module hasn't been initialized.\";\n"; + return oss.str(); + } + + /*! + * \brief Generates the global ndarray pool declaration + * + * \param symobl The Symbol of the current function + * + * \return The created declaration + */ + std::string CreateNDArrayPool(const std::string& symbol) const { + return "Array " + symbol + "_consts;"; + } + + /*! + * \brief Generates the reference to the data of a constant ndarray + * + * \param symobl The Symbol of the current function + * \param symobl const_id The index of the constant + * + * \return The created reference + */ + std::string CreateDataReference(const std::string& symbol, int const_id) const { + return "static_cast(" + symbol + "_consts[" + std::to_string(const_id) + "]->data)"; + } + + /*! + * \brief Returns the variable name for a constant variable + * + * \param symobl The Symbol of the current function + * \param symobl const_id The index of the constant + * + * \return The created variable name + */ + std::string CreateConstVar(const std::string& symbol, int const_id) const { + return symbol + "_const_" + std::to_string(const_id++); + } + /*! \brief The external function source code stream. */ std::ostringstream code_stream_; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 3f9ad7c..60138ae 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -165,33 +165,27 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C std::vector VisitExpr_(const ConstantNode* cn) final { Output output; - output.name = "const_" + std::to_string(const_idx_++); + // Get const: static_cast(dnnl_0_consts[0]->data) + output.name = CreateDataReference(ext_func_id_, const_idx_); output.dtype = "float"; - runtime::NDArray array = cn->data; + // Generate the global variable for needed ndarrays + if (const_array_name_.empty()) { + const_array_name_ = CreateNDArrayPool(ext_func_id_); + std::string checker = CreateInitChecker(ext_func_id_); + ext_func_body_.insert(ext_func_body_.begin(), checker); + } - // Get the number of elements. 
- int64_t num_elems = 1; - for (auto i : array.Shape()) num_elems *= i; + // Give the ndarray a unique name to ease the initialization of it at + // runtime. + std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_); + const_vars_.push_back(const_var_name); + const_idx_++; const auto* type_node = cn->checked_type().as(); CHECK(type_node); CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now."; - std::ostringstream buf_stream; - const float* ptr = static_cast(array->data); - - // Allocate large arrays on the static section to avoid stakc overflow. - // Note that this would probably increase compilation time as the source - // file could be really large. - buf_stream << "static float " << output.name << "[" << num_elems << "] = {"; - for (int64_t i = 0; i < num_elems - 1; i++) { - buf_stream << ptr[i] << ","; - } - if (num_elems > 0) buf_stream << ptr[num_elems - 1]; - buf_stream << "};\n"; - - ext_func_body.insert(ext_func_body.begin(), buf_stream.str()); return {output}; } @@ -204,12 +198,12 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C } buf_decl_.insert(buf_decl_.end(), ret.buffers.begin(), ret.buffers.end()); - ext_func_body.push_back(ret.decl); + ext_func_body_.push_back(ret.decl); return ret.outputs; } std::string JIT(const std::vector& out) { - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out); + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); } private: @@ -341,10 +335,16 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C int const_idx_{0}; /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ Array ext_func_args_; - /*! \brief statement of the function that will be compiled using DNNL kernels. */ - std::vector ext_func_body; + /*! \brief Statement of the function that will be compiled using DNNL kernels. */ + std::vector ext_func_body_; + /*! \brief The array declared to store the constant values. */ + std::string const_array_name_; /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; + /*! \brief The variable name to constant mapping. */ + Array const_vars_; + + friend class DNNLModuleCodegen; }; /*! @@ -355,7 +355,7 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C class DNNLModuleCodegen : public CSourceModuleCodegenBase { public: // Create a corresponding DNNL function for the given relay Function. - void GenDNNLFunc(const Function& func) { + std::pair> GenDNNLFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; // Record the external symbol for runtime lookup. @@ -364,6 +364,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { CodegenDNNL builder(sid); auto out = builder.VisitExpr(func->body); code_stream_ << builder.JIT(out); + + return {sid, builder.const_vars_}; } /*! @@ -382,32 +384,29 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; // dnnl_kernel file is saved under src/runtime/contrib/dnnl so that we don't // expose it to ordinary users. 
To make export_library use it, users need to // pass -I${PATH_TO_TVM}/src/runtime/contrib code_stream_ << "#include \n"; + code_stream_ << "using namespace tvm::runtime;\n"; code_stream_ << "using namespace tvm::runtime::contrib;\n"; code_stream_ << "\n"; - if (ref->IsInstance()) { - GenDNNLFunc(Downcast(ref)); - } else if (ref->IsInstance()) { - IRModule mod = Downcast(ref); - for (const auto& it : mod->functions) { - GenDNNLFunc(Downcast(it.second)); - } - } else { - LOG(FATAL) << "The input ref is expected to be a Relay function or module" - << "\n"; - } + CHECK(ref->IsInstance()); + auto res = GenDNNLFunc(Downcast(ref)); + std::string code = code_stream_.str(); + String sym = std::get<0>(res); + Array variables = std::get<1>(res); - // Create a CSourceModule + // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code_stream_.str(), "cc"); + return (*pf)(code, "c", sym, variables); } private: diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 4226cc8..bc8b390 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -368,6 +368,14 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorGetAttr(tvm::attr::kGlobalSymbol); + std::string symobl = std::string(name_node.value()); + ConstantUpdater const_visit(symobl, ¶ms_); + const_visit(func); + return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 4475d43..cac6f55 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -44,6 +44,26 @@ namespace relay { namespace backend { /*! + * \brief A helper to expand the params by adding the ones used in a given expression. + */ +struct ConstantUpdater : public ExprVisitor { + public: + ConstantUpdater(const std::string& symbol, + std::unordered_map* params) + : symbol_(symbol), params_(params) {} + + void VisitExpr_(const ConstantNode* cn) final { + std::string name = symbol_ + "_const_" + std::to_string(const_idx_++); + (*params_)[name] = cn->data; + } + + private: + int const_idx_{0}; + std::string symbol_; + std::unordered_map* params_; +}; + +/*! * \brief A simple wrapper around ExprFunctor for a single argument case. * The result of visit is memoized. */ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 81db341..0af1949 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -41,6 +41,7 @@ #include #include +#include "../../../target/source/codegen_source_base.h" #include "../../backend/compile_engine.h" #include "../../op/op_common.h" #include "../../transforms/pass_util.h" @@ -996,6 +997,10 @@ void VMCompiler::Codegen() { mod.CopyOnWrite(); if (target_str == "ext_dev") { + // Collect metadata in functions that are handled by external codegen. 
+ CHECK(mod->ContainGlobalVar(cfunc->func_name)); + backend::ConstantUpdater const_visit(cfunc->func_name, ¶ms_); + const_visit(Downcast(mod->Lookup(cfunc->func_name))); continue; } else if (funcs.count(target_str) == 0) { funcs.emplace(target_str, mod); @@ -1006,29 +1011,20 @@ void VMCompiler::Codegen() { auto compile_engine = CompileEngine::Global(); auto ext_mods = compile_engine->LowerExternalFunctions(); - runtime::Module mod; if (funcs.size() > 0) { Map build_funcs; for (const auto& i : funcs) { build_funcs.Set(i.first, i.second); } - mod = tvm::build(build_funcs, target_host_); - CHECK(mod.operator->()); + exec_->lib = tvm::build(build_funcs, target_host_); } else { - CHECK_EQ(ext_mods.size(), 1U) - << "Expect to have a TVM DSOModule when multiple runtime modules exist"; + // There is no function handled by TVM. We create a virtual master module + // to make sure a DSO module will be also available. + exec_->lib = codegen::CSourceModuleCreate(";", ""); } if (!ext_mods.empty()) { - if (funcs.size() == 0) { - mod = ext_mods[0]; - } else { - // Import all external runtime modules. - for (auto it : ext_mods) { - mod.Import(it); - } - } + exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods); } - exec_->lib = mod; } runtime::Module CreateVMCompiler() { diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 59bfb68..146c097 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -198,7 +198,7 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { CHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { int in_idx = GetInputIndex(names[i]); - CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; + if (in_idx < 0) continue; uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); CHECK_LT(eid, data_entry_.size()); @@ -222,7 +222,7 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { CHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { int in_idx = GetInputIndex(names[i]); - CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; + if (in_idx < 0) continue; uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); CHECK_LT(eid, data_entry_.size()); CHECK_EQ(data_entry_[eid].use_count(), 1); @@ -422,8 +422,9 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name, } else { in_idx = args[0]; } - CHECK_GE(in_idx, 0); - *rv = this->GetInput(in_idx); + if (in_idx >= 0) { + *rv = this->GetInput(in_idx); + } }); } else if (name == "get_num_outputs") { return PackedFunc( diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 451c0e8..03dba39 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -26,9 +26,12 @@ #include #include +#include +#include #include #include +#include #include #include "runtime_base.h" @@ -36,6 +39,19 @@ namespace tvm { namespace runtime { +/*! + * \brief Create a metadata module object. + * + * \param metadata The variable name to ndarray mapping. + * \param sym_vars The symbol to the list of required constant variables + * mapping. + * + * \return The created metadata module. + */ +Module MetadataModuleCreate( + const std::unordered_map& metadata, + const std::unordered_map>& sym_vars); + /*! 
\brief function information needed by device */ struct FunctionInfo { std::string name; diff --git a/src/runtime/metadata_module.cc b/src/runtime/metadata_module.cc new file mode 100644 index 0000000..cf3d547 --- /dev/null +++ b/src/runtime/metadata_module.cc @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/metadata_module.cc + * \brief A wrapper for initializing imported modules using metadata. This + * module is intended to be used by various runtime in the TVM stack, i.e. + * graph runtime, relay VM, AOT runtime, and various user defined runtimes. It + * paves the way to separate the code and metedata, which makes compilation + * and/or interpretation more convenient. In addition, the clear separation of + * code and metadata significantly reduces the efforts for handling external + * codegen and runtimes. + */ +#include +#include +#include +#include + +#include +#include + +#include "meta_data.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief The metadata module is designed to manage initialization of the + * imported submodules. + */ +class MetadataModuleNode : public ModuleNode { + public: + MetadataModuleNode(const std::unordered_map& metadata, + const std::unordered_map>& sym_vars) + : metadata_(metadata), sym_vars_(sym_vars) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + // Initialize and memoize the module. + // Usually, we have some warmup runs. The module initialization should be + // done at this stage. Therefore, runtime overhead is not a concern. + if (initialized_.count(name) == 0) { + this->InitSubModule(name); + initialized_.emplace(name); + } + + // Run the module. + // Normally we would only have a limited number of submodules. The runtime + // symobl lookup overhead should be minimal. + CHECK(!this->imports().empty()); + for (Module it : this->imports()) { + PackedFunc pf = it.GetFunction(name); + if (pf != nullptr) return pf; + } + return PackedFunc(nullptr); + } + + const char* type_key() const { return "metadata"; } + + /*! + * \brief Get the list of metadata that is required by the given module. + * \param symbol The symbol that is being queried. + * \return The list of needed NDArray. + */ + Array GetRequiredMetadata(const std::string& symbol) { + Array ret; + CHECK_GT(sym_vars_.count(symbol), 0U) << "No symbol is recorded for " << symbol; + std::vector vars = sym_vars_[symbol]; + for (const auto& it : vars) { + CHECK_GT(metadata_.count(it), 0U) << "Found not recorded constant variable: " << it; + ret.push_back(metadata_[it]); + } + return ret; + } + + /*! + * \brief Initialize each imported module. + * \param symobl The symbol used for initializing a module. It is also used + * for runtime lookup. 
+ * + * \note A module could be like the following: + * MetadataModuleNode (contains all the metadata) + * - CSourceModule + * - JSON runtime module + * + * The initializer iterates through the imported modules and intilizes the + * found module accordingly by passing the needed metadata into it. + */ + void InitSubModule(const std::string& symbol) { + PackedFunc init(nullptr); + for (Module it : this->imports()) { + // Get the initialization function from the imported modules. + std::string init_name = "__init_" + symbol; + init = it.GetFunction(init_name, false); + if (init != nullptr) { + auto md = GetRequiredMetadata(symbol); + // Initialize the module with metadata. + int ret = init(md); + // Report the error if initialization is failed. + CHECK_EQ(ret, 0) << TVMGetLastError(); + break; + } + } + } + + void SaveToBinary(dmlc::Stream* stream) final { + std::vector variables; + std::vector metadata; + for (const auto& it : metadata_) { + String var_name = it.first; + variables.push_back(var_name); + metadata.push_back(it.second); + } + + // Save all variables in the function. + stream->Write(variables); + // Save all constant data. + uint64_t sz = static_cast(metadata.size()); + stream->Write(sz); + for (uint64_t i = 0; i < sz; i++) { + metadata[i].Save(stream); + } + + // Save the symbol to list of required constant variables mapping + std::vector symbols; + std::vector> const_vars; + for (const auto& it : sym_vars_) { + symbols.push_back(it.first); + const_vars.push_back(it.second); + } + + stream->Write(symbols); + sz = static_cast(sym_vars_.size()); + stream->Write(sz); + for (uint64_t i = 0; i < sz; i++) { + stream->Write(const_vars[i]); + } + } + + static Module LoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + + // Load the variables. + std::vector variables; + CHECK(stream->Read(&variables)) << "Loading variables failed"; + uint64_t sz; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading metadata size failed"; + CHECK_EQ(static_cast(sz), variables.size()) + << "The number of variables and ndarray counts must match"; + // Load the list of ndarray. + std::vector arrays; + for (uint64_t i = 0; i < sz; i++) { + NDArray temp; + temp.Load(stream); + arrays.push_back(temp); + } + + std::unordered_map metadata; + for (uint64_t i = 0; i < sz; i++) { + CHECK_EQ(metadata.count(variables[i]), 0U); + metadata[variables[i]] = arrays[i]; + } + + // Load the symbol to list of required constant variables mapping + std::vector symbols; + CHECK(stream->Read(&symbols)) << "Loading symbols failed"; + CHECK(stream->Read(&sz, sizeof(sz))) << "Loading number of symbols failed"; + CHECK_EQ(static_cast(sz), symbols.size()); + std::vector> const_vars; + for (uint64_t i = 0; i < sz; i++) { + std::vector vars; + CHECK(stream->Read(&vars)) << "Loading const variables failed"; + const_vars.push_back(vars); + } + + std::unordered_map> sym_vars; + for (uint64_t i = 0; i < sz; i++) { + sym_vars[symbols[i]] = const_vars[i]; + } + + auto n = make_object(metadata, sym_vars); + return Module(n); + } + + private: + /*! + * \brief Record if a module is initialized. It is needed by imported + * modules using execution engine. + */ + std::unordered_set initialized_; + /*! \brief Variable name to NDArray mapping. */ + std::unordered_map metadata_; + /*! \brief Symbol name to required constant variables mapping. 
*/ + std::unordered_map> sym_vars_; +}; + +Module MetadataModuleCreate( + const std::unordered_map& metadata, + const std::unordered_map>& sym_vars) { + auto n = make_object(metadata, sym_vars); + return Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metadata") + .set_body_typed(MetadataModuleNode::LoadFromBinary); +} // namespace runtime +} // namespace tvm diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 3901659..7e5e403 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -135,9 +135,26 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt); /*! * \brief Create a C source module for viewing and compiling GCC code. * \param code The code to be viewed. - * \param fmt The code. format. + * \param fmt The code format. + * \param symbol The symbol that the c source module represents. + * \param const_vars. The constant variables that the c source module needs. + * \return The created module. + */ +runtime::Module CSourceModuleCreate(const String& code, const String& fmt, + const String& symbol = "", + const Array& const_vars = {}); + +/*! + * \brief Wrap the submodules in a metadata module. + * \param params The variable to constant mapping that is collected by the host + * module. + * \param dso_module The host module to be wrapped. + * \param modules The modules to be wrapped. + * \return The wrapped module. */ -runtime::Module CSourceModuleCreate(std::string code, std::string fmt); +runtime::Module CreateMetadataModule( + const std::unordered_map& params, + const runtime::Module& dso_module, const Array& modules); /*! * \brief Create a source module for viewing and limited saving for device. diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index ba7f075..1e201e5 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -21,6 +21,7 @@ * \file source_module.cc * \brief Source code module, only for viewing */ +#include #include #include @@ -40,6 +41,44 @@ using runtime::GetFileFormat; using runtime::GetMetaFilePath; using runtime::SaveBinaryToFile; +/*! + * \brief Create a metadata module wrapper. The helper is used by different + * codegens, such as graph runtime codegen and the vm compiler. + * + * \param params The metadata for initialization of all modules. + * \param dso_module The DSO module that contains TVM primitives. + * \param modules The submodules that will be wrapped, e.g. CSource modules that + * contain vendor library calls or customized runtime modules. + * + * \return The created metadata module that manages initialization of metadata. + */ +runtime::Module CreateMetadataModule( + const std::unordered_map& params, + const runtime::Module& dso_module, const Array& modules) { + // Wrap all submodules in the initialization wrapper. + std::unordered_map> sym_metadata; + for (runtime::Module it : modules) { + CHECK_EQ(it->type_key(), "c") << "Only csource submodule is handled for now"; + String symbol = it.GetFunction("get_symbol")(); + Array variables = it.GetFunction("get_const_vars")(); + std::vector arrays; + for (size_t i = 0; i < variables.size(); i++) { + arrays.push_back(variables[i].operator std::string()); + } + CHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol; + sym_metadata[symbol] = arrays; + } + + // Wrap the modules. 
+ runtime::Module init_m = runtime::MetadataModuleCreate(params, sym_metadata); + init_m.Import(dso_module); + for (const auto& it : modules) { + init_m.Import(it); + } + + return init_m; +} + // Simulator function class SourceModuleNode : public runtime::ModuleNode { public: @@ -67,13 +106,22 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { // Simulator function class CSourceModuleNode : public runtime::ModuleNode { public: - CSourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + CSourceModuleNode(const std::string& code, const std::string& fmt, const std::string& symbol, + const Array& const_vars) + : code_(code), fmt_(fmt), symbol_(symbol), const_vars_(const_vars) {} const char* type_key() const { return "c"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - LOG(FATAL) << "C Source module cannot execute, to get executable module" - << " build TVM with \'" << fmt_ << "\' runtime support"; - return PackedFunc(); + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_; }); + } else if (name == "get_const_vars") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc(nullptr); + } } std::string GetSource(const std::string& format) final { return code_; } @@ -92,10 +140,14 @@ class CSourceModuleNode : public runtime::ModuleNode { protected: std::string code_; std::string fmt_; + std::string symbol_; + Array const_vars_; }; -runtime::Module CSourceModuleCreate(std::string code, std::string fmt) { - auto n = make_object(code, fmt); +runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const String& symbol, + const Array& const_vars) { + auto n = make_object(code.operator std::string(), fmt.operator std::string(), + symbol.operator std::string(), const_vars); return runtime::Module(n); } @@ -154,6 +206,10 @@ runtime::Module DeviceSourceModuleCreate( TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed(CSourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") + .set_body_typed([](String code, String fmt, String symbol, Array const_vars) { + return CSourceModuleCreate(code, fmt, symbol, const_vars); + }); + } // namespace codegen } // namespace tvm diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 665cb7b..a82f1a5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -78,11 +78,10 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output # Its possible for some onnx inputs to not be needed in the tvm # module, confirm its present before setting. 
try: - m.get_input(input_names[i]) + m.set_input(input_names[i], tvm.nd.array( + input_data[i].astype(input_data[i].dtype))) except: continue - m.set_input(input_names[i], tvm.nd.array( - input_data[i].astype(input_data[i].dtype))) else: m.set_input(input_names, tvm.nd.array( input_data.astype(input_data.dtype))) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index c449ce3..6771bd1 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -259,9 +259,52 @@ def test_extern_dnnl(): (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) +def test_extern_dnnl_const(): + if not tvm.get_global_func("relay.ext.dnnl", True): + print("skip because DNNL codegen is not available") + return + + dtype = 'float32' + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + data0 = relay.var('data0', shape=(ishape), dtype=dtype) + w_data = np.random.uniform(0, 1, w1shape).astype(dtype) + + data1 = relay.var('data0', shape=(ishape), dtype=dtype) + weight1 = relay.const(w_data, dtype=dtype) + weight2 = relay.const(w_data, dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data1, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + weight2, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + f = relay.Function([data1], out) + ref_mod = tvm.IRModule() + ref_mod['main'] = f + + f = set_external_func_attr(f, "dnnl", "dnnl_0") + call = relay.Call(f, [data0]) + mod = tvm.IRModule.from_expr(call) + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + + ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu()) + ref_res = ref_ex.evaluate()(i_data) + check_result(mod, {"data0": i_data}, + (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) + + if __name__ == "__main__": test_multi_node_subgraph() test_extern_gcc_single_op() test_extern_gcc_single_op_int() test_extern_gcc() test_extern_dnnl() + test_extern_dnnl_const() diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py index 3920923..7928e4d 100644 --- a/tests/python/relay/test_external_runtime.py +++ b/tests/python/relay/test_external_runtime.py @@ -109,7 +109,8 @@ def generate_csource_module(): TVM_DLL_EXPORT_TYPED_FUNC(json_rt_0, ccompiler_wrapper_0_); ''' - csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc") + csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", + None) return csource_module @@ -175,7 +176,8 @@ def generate_engine_module(): ''' gen_json_engine() - csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc") + csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", + None) return csource_module diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 8473a67..8ee197d 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -54,7 +54,8 @@ def generate_engine_module(): ''' import tvm.runtime._ffi_api gen_engine_header() - csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc") + csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", + None) return csource_module -- 2.7.4
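
Note on the generated code shape (illustrative only, not part of the patch): with this change the C codegen no longer inlines constant data into the emitted source. Each external symbol instead gets a global Array<NDArray> pool, a "__init_<symbol>" hook that the MetadataModule calls with the recorded NDArrays before the first invocation, and data references of the form static_cast<float*>(<symbol>_consts[i]->data). The sketch below shows roughly what the emitted source for a hypothetical symbol gcc_0 with a single float constant could look like; the symbol name, tensor size, and the element-wise add body are made up for illustration, and the exact formatting produced by JitImpl/GenerateBackendCFunc differs.

#include <cstdint>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/packed_func.h>
#include <dlpack/dlpack.h>
using namespace tvm::runtime;

// Global pool holding the constants this subgraph needs. It is filled by the
// MetadataModule through the __init_gcc_0 hook before the first run.
Array<NDArray> gcc_0_consts;

extern "C" void gcc_0_(float* gcc_0_input0, float* out0) {
  CHECK(!gcc_0_consts.empty()) << "C source module hasn't been initialized.";
  // The constant is read from the pool instead of from a baked-in static array.
  float* gcc_0_const_0 = static_cast<float*>(gcc_0_consts[0]->data);
  for (int64_t i = 0; i < 100; ++i) {  // hypothetical 100-element tensor
    out0[i] = gcc_0_input0[i] + gcc_0_const_0[i];
  }
}

extern "C" int gcc_0_wrapper_(DLTensor* arg0, DLTensor* out0) {
  gcc_0_(static_cast<float*>(arg0->data), static_cast<float*>(out0->data));
  return 0;
}
TVM_DLL_EXPORT_TYPED_FUNC(gcc_0, gcc_0_wrapper_);

// Initialization hook looked up by MetadataModuleNode::InitSubModule as
// "__init_gcc_0"; it receives the NDArrays recorded under the gcc_0 symbol.
int gcc_0_init_wrapper_(Array<NDArray> arr) {
  gcc_0_consts = arr;
  return 0;
}
TVM_DLL_EXPORT_TYPED_FUNC(__init_gcc_0, gcc_0_init_wrapper_);

At runtime, the wrapping module returned by CreateMetadataModule owns the variable-name-to-NDArray map and the symbol-to-required-variables map; the first lookup of gcc_0 triggers InitSubModule, which finds __init_gcc_0 on an imported module and passes it the required NDArrays, after which calls are dispatched to the imported modules as before. This is also why GraphModule.set_input and GraphRuntime::LoadParams now silently skip parameters that do not correspond to graph inputs: those weights are delivered through the metadata module instead.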