[Relay] Invoke tvm::build from relay compile_engine and interpreter (#4723)
authorhlu1 <14827759+hlu1@users.noreply.github.com>
Fri, 17 Jan 2020 16:58:07 +0000 (08:58 -0800)
committerZhi <5145158+zhiics@users.noreply.github.com>
Fri, 17 Jan 2020 16:58:06 +0000 (08:58 -0800)
src/relay/backend/compile_engine.cc
src/relay/backend/interpreter.cc

index 14967c1..9336782 100644 (file)
@@ -599,12 +599,13 @@ class CompileEngineImpl : public CompileEngineNode {
     CCacheValue value = LowerInternal(key);
     if (value->packed_func != nullptr) return value->packed_func;
     // build the function.
+    tvm::runtime::Module m;
     if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
-      tvm::runtime::Module m = (*f)(value->cached_func->funcs, key->target);
-      value->packed_func = m.GetFunction(value->cached_func->func_name);
+      m = (*f)(value->cached_func->funcs, key->target);
     } else {
-      LOG(FATAL) << "relay.backend.build is not registered";
+      m = build(value->cached_func->funcs, key->target, Target(nullptr), BuildConfig::Current());
     }
+    value->packed_func = m.GetFunction(value->cached_func->func_name);
     return value->packed_func;
   }
 
index 68af247..ff9dbba 100644 (file)
@@ -418,13 +418,14 @@ class Interpreter :
       << "Shape function output sizes mismatch";
 
     PackedFunc shape_func;
+    Module m;
     TVMRetValue rv;
     if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
-      tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
-      shape_func = m.GetFunction(cfunc->func_name);
+      m = (*f)(cfunc->funcs, cfunc->target);
     } else {
-      LOG(FATAL) << "relay.backend.build is not registered";
+      m = build(cfunc->funcs, cfunc->target, Target(nullptr), BuildConfig::Current());
     }
+    shape_func = m.GetFunction(cfunc->func_name);
     shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
 
     // Get output shapes