From a2c7f52ca0e733fea8102ae4e23b2903ef676a0a Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Thu, 5 Mar 2020 15:56:44 -0800 Subject: [PATCH] hotfix gcn tutorial fail (#4994) --- python/tvm/relay/build_module.py | 4 ++++ tutorials/frontend/build_gcn.py | 7 +++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index e894933..d1add27 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -222,6 +222,8 @@ def build(mod, target=None, target_host=None, params=None): raise ValueError("Type of input parameter mod must be tvm.IRModule") if isinstance(mod, _expr.Function): + if params: + mod = bind_params_by_name(mod, params) mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " @@ -278,6 +280,8 @@ def optimize(mod, target=None, params=None): raise ValueError("Type of input parameter mod must be tvm.IRModule") if isinstance(mod, _expr.Function): + if params: + mod = bind_params_by_name(mod, params) mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index e0d0aa0..6ac518e 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -314,7 +314,6 @@ layers.append(GraphConv( # Analyze free variables and generate Relay function output = layers[-1] -func = relay.Function(relay.analysis.free_vars(output), output) ###################################################################### # Compile and run with TVM @@ -332,9 +331,13 @@ for i in range(num_layers+1): # Set the TVM build target target = 'llvm' # Currently only support `llvm` as target +func = relay.Function(relay.analysis.free_vars(output), output) +func = relay.build_module.bind_params_by_name(func, params) +mod = tvm.IRModule() +mod["main"] = func # Build with Relay with relay.build_config(opt_level=0): # Currently only support opt_level=0 - graph, lib, params = relay.build(func, target, params=params) + graph, lib, params = relay.build(mod, target, params=params) # Generate graph runtime ctx = tvm.context(target, 0) -- 2.7.4