Support standardize runtime module (#4532)
authorZhao Wu <wuzhaozju@gmail.com>
Sun, 22 Dec 2019 04:14:40 +0000 (12:14 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sun, 22 Dec 2019 04:14:40 +0000 (20:14 -0800)
python/tvm/_ffi/function.py
python/tvm/module.py
src/codegen/codegen.cc
src/runtime/library_module.cc
src/runtime/module.cc
tests/python/unittest/test_runtime_module_export.py [new file with mode: 0644]

index 60e7aeb..ed2f7e1 100644 (file)
@@ -82,6 +82,9 @@ class ModuleBase(object):
     def __del__(self):
         check_call(_LIB.TVMModFree(self.handle))
 
+    def __hash__(self):
+        return ctypes.cast(self.handle, ctypes.c_void_p).value
+
     @property
     def entry_func(self):
         """Get the entry function
index 976fb2d..e9e2294 100644 (file)
@@ -118,31 +118,28 @@ class Module(ModuleBase):
             self.save(file_name)
             return
 
-        if not (self.type_key == "llvm" or self.type_key == "c"):
-            raise ValueError("Module[%s]: Only llvm and c support export shared" % self.type_key)
+        modules = self._collect_dso_modules()
         temp = _util.tempdir()
-        if fcompile is not None and hasattr(fcompile, "object_format"):
-            object_format = fcompile.object_format
-        else:
-            if self.type_key == "llvm":
-                object_format = "o"
+        files = []
+        is_system_lib = False
+        has_c_module = False
+        for index, module in enumerate(modules):
+            if fcompile is not None and hasattr(fcompile, "object_format"):
+                object_format = fcompile.object_format
             else:
-                assert self.type_key == "c"
-                object_format = "cc"
-        path_obj = temp.relpath("lib." + object_format)
-        self.save(path_obj)
-        files = [path_obj]
-        is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
-        has_imported_c_file = False
+                if module.type_key == "llvm":
+                    object_format = "o"
+                else:
+                    assert module.type_key == "c"
+                    object_format = "cc"
+                    has_c_module = True
+            path_obj = temp.relpath("lib" + str(index) + "." + object_format)
+            module.save(path_obj)
+            files.append(path_obj)
+            is_system_lib = (module.type_key == "llvm" and
+                             module.get_function("__tvm_is_system_module")())
+
         if self.imported_modules:
-            for i, m in enumerate(self.imported_modules):
-                if m.type_key == "c":
-                    has_imported_c_file = True
-                    c_file_name = "tmp_" + str(i) + ".cc"
-                    path_cc = temp.relpath(c_file_name)
-                    with open(path_cc, "w") as f:
-                        f.write(m.get_source())
-                    files.append(path_cc)
             path_cc = temp.relpath("devc.cc")
             with open(path_cc, "w") as f:
                 f.write(_PackImportsToC(self, is_system_lib))
@@ -152,13 +149,15 @@ class Module(ModuleBase):
                 fcompile = _tar.tar
             else:
                 fcompile = _cc.create_shared
-        if self.type_key == "c" or has_imported_c_file:
+
+        if has_c_module:
             options = []
             if "options" in kwargs:
                 opts = kwargs["options"]
                 options = opts if isinstance(opts, (list, tuple)) else [opts]
             opts = options + ["-I" + path for path in find_include_path()]
             kwargs.update({'options': opts})
+
         fcompile(file_name, files, **kwargs)
 
     def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
@@ -219,6 +218,25 @@ class Module(ModuleBase):
         except NameError:
             raise NameError("time_evaluate is only supported when RPC is enabled")
 
+    def _collect_dso_modules(self):
+        """Helper function to collect dso modules, then return it."""
+        visited, stack, dso_modules = set(), [], []
+        # append root module
+        visited.add(self)
+        stack.append(self)
+        while stack:
+            module = stack.pop()
+            if module._dso_exportable():
+                dso_modules.append(module)
+            for m in module.imported_modules:
+                if m not in visited:
+                    visited.add(m)
+                    stack.append(m)
+        return dso_modules
+
+    def _dso_exportable(self):
+        return self.type_key == "llvm" or self.type_key == "c"
+
 
 def system_lib():
     """Get system-wide library module singleton.
index 6ce76f6..60b12dc 100644 (file)
 #include <tvm/build_module.h>
 #include <dmlc/memory_io.h>
 #include <sstream>
-#include <iostream>
+#include <vector>
+#include <cstdint>
+#include <unordered_set>
+#include <cstring>
 
 namespace tvm {
 namespace codegen {
@@ -58,20 +61,111 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
   return m;
 }
 
+/*! \brief Helper class to serialize module */
+class ModuleSerializer {
+ public:
+  explicit ModuleSerializer(runtime::Module mod) : mod_(mod) {
+    Init();
+  }
+
+  void SerializeModule(dmlc::Stream* stream) {
+    // Only have one DSO module and it is in the root, then
+    // we will not produce import_tree_.
+    bool has_import_tree = true;
+    if (DSOExportable(mod_.operator->()) && mod_->imports().empty()) {
+      has_import_tree = false;
+    }
+    uint64_t sz = 0;
+    if (has_import_tree) {
+      // we will append one key for _import_tree
+      // The layout is the same as before: binary_size, key, logic, key, logic...
+      sz = mod_vec_.size() + 1;
+    } else {
+      // Keep the old behaviour
+      sz = mod_->imports().size();
+    }
+    stream->Write(sz);
+
+    for (auto m : mod_vec_) {
+      std::string mod_type_key = m->type_key();
+      if (!DSOExportable(m)) {
+        stream->Write(mod_type_key);
+        m->SaveToBinary(stream);
+      } else if (has_import_tree) {
+        mod_type_key = "_lib";
+        stream->Write(mod_type_key);
+      }
+    }
+
+    // Write _import_tree key if we have
+    if (has_import_tree) {
+      std::string import_key = "_import_tree";
+      stream->Write(import_key);
+      stream->Write(import_tree_row_ptr_);
+      stream->Write(import_tree_child_indices_);
+    }
+  }
+
+ private:
+  void Init() {
+    CreateModuleIndex();
+    CreateImportTree();
+  }
+
+  // invariance: root module is always at location 0.
+  // The module order is collected via DFS
+  void CreateModuleIndex() {
+    std::unordered_set<const runtime::ModuleNode*> visited {mod_.operator->()};
+    std::vector<runtime::ModuleNode*> stack {mod_.operator->()};
+    uint64_t module_index = 0;
+
+    while (!stack.empty()) {
+      runtime::ModuleNode* n = stack.back();
+      stack.pop_back();
+      mod2index_[n] = module_index++;
+      mod_vec_.emplace_back(n);
+      for (runtime::Module m : n->imports()) {
+        runtime::ModuleNode* next = m.operator->();
+        if (visited.count(next) == 0) {
+          visited.insert(next);
+          stack.push_back(next);
+        }
+      }
+    }
+  }
+
+  void CreateImportTree() {
+    for (auto m : mod_vec_) {
+      for (runtime::Module im : m->imports()) {
+        uint64_t mod_index = mod2index_[im.operator->()];
+        import_tree_child_indices_.push_back(mod_index);
+      }
+      import_tree_row_ptr_.push_back(import_tree_child_indices_.size());
+    }
+  }
+
+  bool DSOExportable(const runtime::ModuleNode* mod) {
+    return !std::strcmp(mod->type_key(), "llvm") ||
+           !std::strcmp(mod->type_key(), "c");
+  }
+
+  runtime::Module mod_;
+  // construct module to index
+  std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
+  // index -> module
+  std::vector<runtime::ModuleNode*> mod_vec_;
+  std::vector<uint64_t> import_tree_row_ptr_ {0};
+  std::vector<uint64_t> import_tree_child_indices_;
+};
+
 std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
   std::string bin;
   dmlc::MemoryStringStream ms(&bin);
   dmlc::Stream* stream = &ms;
-  uint64_t sz = static_cast<uint64_t>(mod->imports().size());
-  stream->Write(sz);
-  for (runtime::Module im : mod->imports()) {
-    CHECK_EQ(im->imports().size(), 0U)
-        << "Only support simply one-level hierarchy";
-    std::string tkey = im->type_key();
-    stream->Write(tkey);
-    if (tkey == "c") continue;
-    im->SaveToBinary(stream);
-  }
+
+  ModuleSerializer module_serializer(mod);
+  module_serializer.SerializeModule(stream);
+
   // translate to C program
   std::ostringstream os;
   os << "#ifdef _WIN32\n"
index d3283bc..9aaf5b9 100644 (file)
@@ -28,6 +28,7 @@
 #include <tvm/runtime/registry.h>
 #include <string>
 #include <vector>
+#include <cstdint>
 #include "library_module.h"
 
 namespace tvm {
@@ -108,9 +109,11 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
 /*!
  * \brief Load and append module blob to module list
  * \param mblob The module blob.
- * \param module_list The module list to append to
+ * \param lib The library.
+ *
+ * \return Root Module.
  */
-void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
+runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
 #ifndef _LIBCPP_SGX_CONFIG
   CHECK(mblob != nullptr);
   uint64_t nbytes = 0;
@@ -123,20 +126,56 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
   dmlc::Stream* stream = &fs;
   uint64_t size;
   CHECK(stream->Read(&size));
+  std::vector<Module> modules;
+  std::vector<uint64_t> import_tree_row_ptr;
+  std::vector<uint64_t> import_tree_child_indices;
   for (uint64_t i = 0; i < size; ++i) {
     std::string tkey;
     CHECK(stream->Read(&tkey));
-    if (tkey == "c") continue;
-    std::string fkey = "module.loadbinary_" + tkey;
-    const PackedFunc* f = Registry::Get(fkey);
-    CHECK(f != nullptr)
+    // Currently, _lib is for DSOModule, but we
+    // don't have loadbinary function for it currently
+    if (tkey == "_lib") {
+      auto dso_module = Module(make_object<LibraryModuleNode>(lib));
+      modules.emplace_back(dso_module);
+    } else if (tkey == "_import_tree") {
+      CHECK(stream->Read(&import_tree_row_ptr));
+      CHECK(stream->Read(&import_tree_child_indices));
+    } else {
+      std::string fkey = "module.loadbinary_" + tkey;
+      const PackedFunc* f = Registry::Get(fkey);
+      CHECK(f != nullptr)
         << "Loader of " << tkey << "("
         << fkey << ") is not presented.";
-    Module m = (*f)(static_cast<void*>(stream));
-    mlist->push_back(m);
+      Module m = (*f)(static_cast<void*>(stream));
+      modules.emplace_back(m);
+    }
   }
+  // if we are using old dll, we don't have import tree
+  // so that we can't reconstruct module relationship using import tree
+  if (import_tree_row_ptr.empty()) {
+    auto n = make_object<LibraryModuleNode>(lib);
+    auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->());
+    for (const auto& m : modules) {
+      module_import_addr->emplace_back(m);
+    }
+    return Module(n);
+  } else {
+    for (size_t i = 0; i < modules.size(); ++i) {
+      for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
+        auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->());
+        auto child_index = import_tree_child_indices[j];
+        CHECK(child_index < modules.size());
+        module_import_addr->emplace_back(modules[child_index]);
+      }
+    }
+  }
+  CHECK(!modules.empty());
+  // invariance: root module is always at location 0.
+  // The module order is collected via DFS
+  return modules[0];
 #else
   LOG(FATAL) << "SGX does not support ImportModuleBlob";
+  return Module();
 #endif
 }
 
@@ -149,17 +188,20 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
   const char* dev_mblob =
       reinterpret_cast<const char*>(
           lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
+  Module root_mod;
   if (dev_mblob != nullptr) {
-    ImportModuleBlob(
-        dev_mblob, ModuleInternal::GetImportsAddr(n.operator->()));
+    root_mod = ProcessModuleBlob(dev_mblob, lib);
+  } else {
+    // Only have one single DSO Module
+    root_mod = Module(n);
   }
 
-  Module root_mod = Module(n);
-  // allow lookup of symbol from root(so all symbols are visible).
+  // allow lookup of symbol from root (so all symbols are visible).
   if (auto *ctx_addr =
       reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
     *ctx_addr = root_mod.operator->();
   }
+
   return root_mod;
 }
 }  // namespace runtime
index 161675c..2f3e337 100644 (file)
@@ -115,7 +115,7 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
   if (it != import_cache_.end()) return it->second.get();
   PackedFunc pf;
   for (Module& m : this->imports_) {
-    pf = m.GetFunction(name, false);
+    pf = m.GetFunction(name, true);
     if (pf != nullptr) break;
   }
   if (pf == nullptr) {
diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py
new file mode 100644 (file)
index 0000000..b676cf2
--- /dev/null
@@ -0,0 +1,208 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from tvm import relay
+from tvm.relay import testing
+import tvm
+
+from tvm.contrib import util
+header_file_dir_path = util.tempdir()
+
+
+def gen_engine_header():
+    code = r'''
+        #ifndef _ENGINE_H_
+        #define _ENGINE_H_
+        #include <cstdint>
+        #include <string>
+        #include <sstream>
+        #include <vector>
+        class Engine {
+        };
+    
+        #endif
+        '''
+    header_file = header_file_dir_path.relpath("gcc_engine.h")
+    with open(header_file, 'w') as f:
+        f.write(code)
+
+
+def generate_engine_module():
+    code = r'''
+        #include <tvm/runtime/c_runtime_api.h>
+        #include <dlpack/dlpack.h>
+        #include "gcc_engine.h"
+    
+        extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5,
+                float* gcc_input6, float* gcc_input7, float* out) {
+            Engine engine;
+        }
+        '''
+    gen_engine_header()
+    csource_module = tvm.module.csource_module_create(code, "cc")
+    return csource_module
+
+
+def test_mod_export():
+    def verify_gpu_mod_export(obj_format):
+        for device in ["llvm", "cuda"]:
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled..." % device)
+                return
+
+        resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18)
+        resnet50_mod, resnet50_params = relay.testing.resnet.get_workload(num_layers=50)
+        with relay.build_config(opt_level=3):
+            _, resnet18_gpu_lib, _ = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params)
+            _, resnet50_cpu_lib, _ = relay.build_module.build(resnet50_mod, "llvm", params=resnet50_params)
+
+        from tvm.contrib import util
+        temp = util.tempdir()
+        if obj_format == ".so":
+            file_name = "deploy_lib.so"
+        else:
+            assert obj_format == ".tar"
+            file_name = "deploy_lib.tar"
+        path_lib = temp.relpath(file_name)
+        resnet18_gpu_lib.imported_modules[0].import_module(resnet50_cpu_lib)
+        resnet18_gpu_lib.export_library(path_lib)
+        loaded_lib = tvm.module.load(path_lib)
+        assert loaded_lib.type_key == "library"
+        assert loaded_lib.imported_modules[0].type_key == "cuda"
+        assert loaded_lib.imported_modules[0].imported_modules[0].type_key == "library"
+
+    def verify_multi_dso_mod_export(obj_format):
+        for device in ["llvm"]:
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled..." % device)
+                return
+
+        resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18)
+        with relay.build_config(opt_level=3):
+            _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params)
+
+        A = tvm.placeholder((1024,), name='A')
+        B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+        s = tvm.create_schedule(B.op)
+        f = tvm.build(s, [A, B], "llvm", name="myadd")
+        from tvm.contrib import util
+        temp = util.tempdir()
+        if obj_format == ".so":
+            file_name = "deploy_lib.so"
+        else:
+            assert obj_format == ".tar"
+            file_name = "deploy_lib.tar"
+        path_lib = temp.relpath(file_name)
+        resnet18_cpu_lib.import_module(f)
+        resnet18_cpu_lib.export_library(path_lib)
+        loaded_lib = tvm.module.load(path_lib)
+        assert loaded_lib.type_key == "library"
+        assert loaded_lib.imported_modules[0].type_key == "library"
+
+    def verify_json_import_dso(obj_format):
+        for device in ["llvm"]:
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled..." % device)
+                return
+
+        # Get subgraph Json.
+        subgraph_json = ("json_rt_0\n" +
+                         "input 0 10 10\n" +
+                         "input 1 10 10\n" +
+                         "input 2 10 10\n" +
+                         "input 3 10 10\n" +
+                         "add 4 inputs: 0 1 shape: 10 10\n" +
+                         "sub 5 inputs: 4 2 shape: 10 10\n" +
+                         "mul 6 inputs: 5 3 shape: 10 10\n" +
+                         "json_rt_1\n" +
+                         "input 0 10 10\n" +
+                         "input 1 10 10\n" +
+                         "input 2 10 10\n" +
+                         "input 3 10 10\n" +
+                         "add 4 inputs: 0 1 shape: 10 10\n" +
+                         "sub 5 inputs: 4 2 shape: 10 10\n" +
+                         "mul 6 inputs: 5 3 shape: 10 10")
+
+        from tvm.contrib import util
+        temp = util.tempdir()
+        subgraph_path = temp.relpath('subgraph.examplejson')
+        with open(subgraph_path, 'w') as f:
+            f.write(subgraph_json)
+
+        # Get Json and module.
+        A = tvm.placeholder((1024,), name='A')
+        B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+        s = tvm.create_schedule(B.op)
+        f = tvm.build(s, [A, B], "llvm", name="myadd")
+        try:
+            ext_lib = tvm.module.load(subgraph_path, "examplejson")
+        except:
+            print("skip because Loader of examplejson is not presented")
+            return
+        ext_lib.import_module(f)
+        if obj_format == ".so":
+            file_name = "deploy_lib.so"
+        else:
+            assert obj_format == ".tar"
+            file_name = "deploy_lib.tar"
+        path_lib = temp.relpath(file_name)
+        ext_lib.export_library(path_lib)
+        lib = tvm.module.load(path_lib)
+        assert lib.type_key == "examplejson"
+        assert lib.imported_modules[0].type_key == "library"
+
+    def verify_multi_c_mod_export():
+        from shutil import which
+        if which("gcc") is None:
+            print("Skip test because gcc is not available.")
+
+        for device in ["llvm"]:
+            if not tvm.module.enabled(device):
+                print("skip because %s is not enabled..." % device)
+                return
+
+        resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18)
+        with relay.build_config(opt_level=3):
+            _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params)
+
+        A = tvm.placeholder((1024,), name='A')
+        B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
+        s = tvm.create_schedule(B.op)
+        f = tvm.build(s, [A, B], "c", name="myadd")
+        engine_module = generate_engine_module()
+        from tvm.contrib import util
+        temp = util.tempdir()
+        file_name = "deploy_lib.so"
+        path_lib = temp.relpath(file_name)
+        resnet18_cpu_lib.import_module(f)
+        resnet18_cpu_lib.import_module(engine_module)
+        kwargs = {"options": ["-O2", "-std=c++11", "-I" + header_file_dir_path.relpath("")]}
+        resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs)
+        loaded_lib = tvm.module.load(path_lib)
+        assert loaded_lib.type_key == "library"
+        assert loaded_lib.imported_modules[0].type_key == "library"
+        assert loaded_lib.imported_modules[1].type_key == "library"
+
+    for obj_format in [".so", ".tar"]:
+        verify_gpu_mod_export(obj_format)
+        verify_multi_dso_mod_export(obj_format)
+        verify_json_import_dso(obj_format)
+
+    verify_multi_c_mod_export()
+
+
+if __name__ == "__main__":
+    test_mod_export()