From 3279957f8faec099e8a25162b2d14e67a111c210 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Fri, 17 Jan 2020 08:58:07 -0800 Subject: [PATCH] [Relay] Invoke tvm::build from relay compile_engine and interpreter (#4723) --- src/relay/backend/compile_engine.cc | 7 ++++--- src/relay/backend/interpreter.cc | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 14967c1..9336782 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -599,12 +599,13 @@ class CompileEngineImpl : public CompileEngineNode { CCacheValue value = LowerInternal(key); if (value->packed_func != nullptr) return value->packed_func; // build the function. + tvm::runtime::Module m; if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - tvm::runtime::Module m = (*f)(value->cached_func->funcs, key->target); - value->packed_func = m.GetFunction(value->cached_func->func_name); + m = (*f)(value->cached_func->funcs, key->target); } else { - LOG(FATAL) << "relay.backend.build is not registered"; + m = build(value->cached_func->funcs, key->target, Target(nullptr), BuildConfig::Current()); } + value->packed_func = m.GetFunction(value->cached_func->func_name); return value->packed_func; } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 68af247..ff9dbba 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -418,13 +418,14 @@ class Interpreter : << "Shape function output sizes mismatch"; PackedFunc shape_func; + Module m; TVMRetValue rv; if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target); - shape_func = m.GetFunction(cfunc->func_name); + m = (*f)(cfunc->funcs, cfunc->target); } else { - LOG(FATAL) << "relay.backend.build is not registered"; + m = build(cfunc->funcs, cfunc->target, Target(nullptr), BuildConfig::Current()); } + shape_func = m.GetFunction(cfunc->func_name); shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); // Get output shapes -- 2.7.4