hotfix gcn tutorial fail (#4994)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 5 Mar 2020 23:56:44 +0000 (15:56 -0800)
committerGitHub <noreply@github.com>
Thu, 5 Mar 2020 23:56:44 +0000 (15:56 -0800)
python/tvm/relay/build_module.py
tutorials/frontend/build_gcn.py

index e894933..d1add27 100644 (file)
@@ -222,6 +222,8 @@ def build(mod, target=None, target_host=None, params=None):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
     if isinstance(mod, _expr.Function):
+        if params:
+            mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
@@ -278,6 +280,8 @@ def optimize(mod, target=None, params=None):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
     if isinstance(mod, _expr.Function):
+        if params:
+            mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
index e0d0aa0..6ac518e 100644 (file)
@@ -314,7 +314,6 @@ layers.append(GraphConv(
 
 # Analyze free variables and generate Relay function
 output = layers[-1]
-func = relay.Function(relay.analysis.free_vars(output), output)
 
 ######################################################################
 # Compile and run with TVM
@@ -332,9 +331,13 @@ for i in range(num_layers+1):
 # Set the TVM build target
 target = 'llvm' # Currently only support `llvm` as target
 
+func = relay.Function(relay.analysis.free_vars(output), output)
+func = relay.build_module.bind_params_by_name(func, params)
+mod = tvm.IRModule()
+mod["main"] = func
 # Build with Relay
 with relay.build_config(opt_level=0): # Currently only support opt_level=0
-    graph, lib, params = relay.build(func, target, params=params)
+    graph, lib, params = relay.build(mod, target, params=params)
 
 # Generate graph runtime
 ctx = tvm.context(target, 0)