vm external codegen (#4544)
authorZhi <5145158+zhiics@users.noreply.github.com>
Fri, 20 Dec 2019 22:36:14 +0000 (14:36 -0800)
committerHaichen Shen <shenhaichen@gmail.com>
Fri, 20 Dec 2019 22:36:14 +0000 (14:36 -0800)
src/relay/backend/vm/compiler.cc
src/runtime/vm/vm.cc
tests/python/relay/test_external_codegen.py

index c38ca1a..137d60b 100644 (file)
@@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       argument_registers.push_back(reg->second);
     }
 
-    // Next generate the invoke instruction.
     Target target;
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      for (auto kv : targets_) {
-        target = kv.second;
-      }
+
+    if (!func->UseDefaultCompiler()) {
+      target = tvm::target::ext_dev();
     } else {
-      // heterogeneous execution.
-      LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
+      // Next generate the invoke instruction.
+      if (targets_.size() == 1) {
+        // homogeneous execution.
+        const auto& it = targets_.begin();
+        target = (*it).second;
+      } else {
+        // heterogeneous execution.
+        LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
+      }
     }
 
     auto key = CCacheKeyNode::make(func, target);
     auto cfunc = engine_->Lower(key);
 
-    // TODO(jroesch): support lowered funcs for multiple targets
-    CHECK_EQ(cfunc->funcs.size(), 1);
     auto op_index = -1;
-    if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+    if (!func->UseDefaultCompiler()) {
       op_index = context_->cached_funcs.size();
       context_->cached_funcs.push_back(cfunc);
-      context_->seen_funcs[cfunc->funcs[0]] = op_index;
     } else {
-      op_index = context_->seen_funcs[cfunc->funcs[0]];
+      // TODO(jroesch): support lowered funcs for multiple targets
+      CHECK_EQ(cfunc->funcs.size(), 1);
+      if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+        op_index = context_->cached_funcs.size();
+        context_->cached_funcs.push_back(cfunc);
+        context_->seen_funcs[cfunc->funcs[0]] = op_index;
+      } else {
+        op_index = context_->seen_funcs[cfunc->funcs[0]];
+      }
     }
 
     Emit(Instruction::InvokePacked(op_index,
@@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() {
   if (cached_funcs.size() == 0) {
     return;
   }
-  std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
-  for (auto &cfunc : cached_funcs) {
+  std::unordered_map<std::string, Array<LoweredFunc>> funcs;
+  for (autocfunc : cached_funcs) {
     std::string target_str = cfunc->target->str();
-    if (tgt_funcs.count(target_str) == 0) {
-      tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
+    if (target_str == "ext_dev") {
+      continue;
+    } else if (funcs.count(target_str) == 0) {
+      funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
     } else {
-      tgt_funcs[target_str].push_back(cfunc->funcs[0]);
+      funcs[target_str].push_back(cfunc->funcs[0]);
     }
   }
-  Map<Target, Array<LoweredFunc>> funcs;
-  for (auto &it : tgt_funcs) {
-    funcs.Set(Target::Create(it.first), it.second);
-  }
 
-  if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
-    // The target is just a dummy arg because funcs already contains corresponding target
-    // therefore target won't be used in the build function
-    runtime::Module mod = (*f)(funcs, Target(), target_host_);
+  auto compile_engine = CompileEngine::Global();
+  auto ext_mods = compile_engine->LowerExternalFunctions();
+  runtime::Module mod;
+  if (funcs.size() > 0) {
+    mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
     CHECK(mod.operator->());
-    exec_->lib = mod;
   } else {
-    LOG(FATAL) << "relay.backend.build is not registered";
+    CHECK_EQ(ext_mods.size(), 1U)
+        << "Expect to have a TVM DSOModule when multiple runtime modules exist";
+  }
+  if (!ext_mods.empty()) {
+    if (funcs.size() == 0) {
+      mod = ext_mods[0];
+    } else {
+      // Import all external runtime modules.
+      for (auto it : ext_mods) {
+        mod.Import(it);
+      }
+    }
   }
+  exec_->lib = mod;
   size_t primitive_index = 0;
   for (auto cfunc : cached_funcs) {
-    exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+    if (cfunc->target->str() == "ext_dev") {
+      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
+    } else {
+      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+    }
   }
 }
 
index 41fe71a..a3b11d4 100644 (file)
@@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
     if (packed_funcs_.size() <= packed_index) {
       packed_funcs_.resize(packed_index + 1);
     }
-    packed_funcs_[packed_index] = lib.GetFunction(packed_name);
+    tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true);
+    CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
+    packed_funcs_[packed_index] = pf;
   }
 }
 
index fb0a8a2..2cf32e7 100644 (file)
@@ -26,36 +26,54 @@ import tvm.relay.transform
 from tvm import relay
 from tvm.contrib import util
 
-def check_result(mod, map_inputs, out_shape, result, tol=1e-5):
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
+                 ctx=tvm.cpu()):
     if sys.platform == "win32":
         print("Skip test on Windows for now")
         return
 
-    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
-        json, lib, _ = relay.build(mod, "llvm")
-    test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
-    source_dir = os.path.join(test_dir, "..", "..", "..")
-    contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
-
-    kwargs = {}
-    kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
-    tmp_path = util.tempdir()
-    lib_name = 'lib.so'
-    lib_path = tmp_path.relpath(lib_name)
-    lib.export_library(lib_path, fcompile=False, **kwargs)
-    lib = tvm.module.load(lib_path)
-
-    ctx = tvm.cpu()
-    rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
-
-    for name, data in map_inputs.items():
-        rt_mod.set_input(name, data)
-
-    rt_mod.run()
-    out = tvm.nd.empty(out_shape, ctx=ctx)
-    out = rt_mod.get_output(0, out)
-
-    tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+    def update_lib(lib):
+        test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+        source_dir = os.path.join(test_dir, "..", "..", "..")
+        contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+        kwargs = {}
+        kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
+        tmp_path = util.tempdir()
+        lib_name = 'lib.so'
+        lib_path = tmp_path.relpath(lib_name)
+        lib.export_library(lib_path, fcompile=False, **kwargs)
+        lib = tvm.module.load(lib_path)
+
+        return lib
+
+    def check_vm_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            exe = relay.vm.compile(mod, target=target)
+        code, lib = exe.save()
+        lib = update_lib(lib)
+        exe = relay.vm.Executable.load_exec(code, lib)
+        vm = relay.vm.VirtualMachine(exe)
+        vm.init(ctx)
+        out = vm.run(**map_inputs)
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    def check_graph_runtime_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            json, lib, _ = relay.build(mod, target=target)
+        lib = update_lib(lib)
+        rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+        for name, data in map_inputs.items():
+            rt_mod.set_input(name, data)
+        rt_mod.run()
+        out = tvm.nd.empty(out_shape, ctx=ctx)
+        out = rt_mod.get_output(0, out)
+
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    check_vm_result()
+    check_graph_runtime_result()
 
 
 def set_external_func_attr(func, compiler, ext_symbol):