argument_registers.push_back(reg->second);
}
- // Next generate the invoke instruction.
Target target;
- if (targets_.size() == 1) {
- // homogeneous execution.
- for (auto kv : targets_) {
- target = kv.second;
- }
+
+ if (!func->UseDefaultCompiler()) {
+ target = tvm::target::ext_dev();
} else {
- // heterogeneous execution.
- LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
+ // Next generate the invoke instruction.
+ if (targets_.size() == 1) {
+ // homogeneous execution.
+ const auto& it = targets_.begin();
+ target = (*it).second;
+ } else {
+ // heterogeneous execution.
+ LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
+ }
}
auto key = CCacheKeyNode::make(func, target);
auto cfunc = engine_->Lower(key);
- // TODO(jroesch): support lowered funcs for multiple targets
- CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
- if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+ if (!func->UseDefaultCompiler()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
- context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
- op_index = context_->seen_funcs[cfunc->funcs[0]];
+ // TODO(jroesch): support lowered funcs for multiple targets
+ CHECK_EQ(cfunc->funcs.size(), 1);
+ if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+ op_index = context_->cached_funcs.size();
+ context_->cached_funcs.push_back(cfunc);
+ context_->seen_funcs[cfunc->funcs[0]] = op_index;
+ } else {
+ op_index = context_->seen_funcs[cfunc->funcs[0]];
+ }
}
Emit(Instruction::InvokePacked(op_index,
if (cached_funcs.size() == 0) {
return;
}
- std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
- for (auto &cfunc : cached_funcs) {
+ std::unordered_map<std::string, Array<LoweredFunc>> funcs;
+ for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str();
- if (tgt_funcs.count(target_str) == 0) {
- tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
+ if (target_str == "ext_dev") {
+ continue;
+ } else if (funcs.count(target_str) == 0) {
+ funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
} else {
- tgt_funcs[target_str].push_back(cfunc->funcs[0]);
+ funcs[target_str].push_back(cfunc->funcs[0]);
}
}
- Map<Target, Array<LoweredFunc>> funcs;
- for (auto &it : tgt_funcs) {
- funcs.Set(Target::Create(it.first), it.second);
- }
- if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
- // The target is just a dummy arg because funcs already contains corresponding target
- // therefore target won't be used in the build function
- runtime::Module mod = (*f)(funcs, Target(), target_host_);
+ auto compile_engine = CompileEngine::Global();
+ auto ext_mods = compile_engine->LowerExternalFunctions();
+ runtime::Module mod;
+ if (funcs.size() > 0) {
+ mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
CHECK(mod.operator->());
- exec_->lib = mod;
} else {
- LOG(FATAL) << "relay.backend.build is not registered";
+ CHECK_EQ(ext_mods.size(), 1U)
+ << "Expect to have a TVM DSOModule when multiple runtime modules exist";
+ }
+ if (!ext_mods.empty()) {
+ if (funcs.size() == 0) {
+ mod = ext_mods[0];
+ } else {
+ // Import all external runtime modules.
+ for (auto it : ext_mods) {
+ mod.Import(it);
+ }
+ }
}
+ exec_->lib = mod;
size_t primitive_index = 0;
for (auto cfunc : cached_funcs) {
- exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+ if (cfunc->target->str() == "ext_dev") {
+ exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
+ } else {
+ exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+ }
}
}
from tvm import relay
from tvm.contrib import util
-def check_result(mod, map_inputs, out_shape, result, tol=1e-5):
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
+ ctx=tvm.cpu()):
if sys.platform == "win32":
print("Skip test on Windows for now")
return
- with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
- json, lib, _ = relay.build(mod, "llvm")
- test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
- source_dir = os.path.join(test_dir, "..", "..", "..")
- contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
-
- kwargs = {}
- kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
- tmp_path = util.tempdir()
- lib_name = 'lib.so'
- lib_path = tmp_path.relpath(lib_name)
- lib.export_library(lib_path, fcompile=False, **kwargs)
- lib = tvm.module.load(lib_path)
-
- ctx = tvm.cpu()
- rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
-
- for name, data in map_inputs.items():
- rt_mod.set_input(name, data)
-
- rt_mod.run()
- out = tvm.nd.empty(out_shape, ctx=ctx)
- out = rt_mod.get_output(0, out)
-
- tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+ def update_lib(lib):
+ test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+ source_dir = os.path.join(test_dir, "..", "..", "..")
+ contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+ kwargs = {}
+ kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
+ tmp_path = util.tempdir()
+ lib_name = 'lib.so'
+ lib_path = tmp_path.relpath(lib_name)
+ lib.export_library(lib_path, fcompile=False, **kwargs)
+ lib = tvm.module.load(lib_path)
+
+ return lib
+
+ def check_vm_result():
+ with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+ exe = relay.vm.compile(mod, target=target)
+ code, lib = exe.save()
+ lib = update_lib(lib)
+ exe = relay.vm.Executable.load_exec(code, lib)
+ vm = relay.vm.VirtualMachine(exe)
+ vm.init(ctx)
+ out = vm.run(**map_inputs)
+ tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+ def check_graph_runtime_result():
+ with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+ json, lib, _ = relay.build(mod, target=target)
+ lib = update_lib(lib)
+ rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+ for name, data in map_inputs.items():
+ rt_mod.set_input(name, data)
+ rt_mod.run()
+ out = tvm.nd.empty(out_shape, ctx=ctx)
+ out = rt_mod.get_output(0, out)
+
+ tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+ check_vm_result()
+ check_graph_runtime_result()
def set_external_func_attr(func, compiler, ext_symbol):