From d2799915db87107c83ef105a2a628fc54b1cada4 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 4 Feb 2020 09:32:45 -0800 Subject: [PATCH] [AutoTVM] Minor bug fixes in AutoTVM for QNN graphs (#4797) * [AutoTVM] Minor bug fixes in AutoTVM for QNN graphs. * Bring back strided_slice. * Replace tvm.nd change. --- python/tvm/autotvm/graph_tuner/utils/traverse_graph.py | 7 ++++--- python/tvm/autotvm/task/relay_integration.py | 7 +++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 19c3193..f58dd28 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -126,10 +126,10 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list): for i, input_idx in enumerate(node_entry["inputs"]): input_node_entry = node_list[input_idx[0]] input_type = input_node_entry["types"][input_idx[1]] - if not isinstance(input_node_entry["node"], (Var, Call)): + if not isinstance(input_node_entry["node"], (Var, Constant, Call)): raise RuntimeError("Graph tuner can only tune target " "operators with input node of type " - "relay.expr.Var or relay.expr.Call. Now " + "relay.expr.Var/Constant/Call. Now " "find a target op %s with input type %s" % (op_name, str(type(input_node_entry["node"])))) free_var = relay.Var("var_%d" % i, input_type) @@ -167,7 +167,8 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list): else: node_entry["inputs"].append([in_node_idx, 0, 0]) elif isinstance(node, Constant): - pass + node_entry["name"] = "Constant_" + str(node_index) + node_entry["types"] = [node.checked_type] elif isinstance(node, relay.op.op.Op): return else: diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 55763af..3eb1f1d 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -50,7 +50,8 @@ def _lower(mod, grc.codegen(mod["main"]) # default case compiler = relay.vm.VMCompiler() - compiler.set_params(params) + if params: + compiler.set_params(params) compiler.lower(mod, target=target) @@ -123,7 +124,9 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, # relay op -> topi compute OP2TOPI = { tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, - topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc], + topi.nn.group_conv2d_nchw, + topi.nn.conv2d_NCHWc, + topi.nn.conv2d_NCHWc_int8], tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], tvm.relay.op.nn.dense: [topi.nn.dense], tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul], -- 2.7.4