From 03cbf78e3e8cc910c5a9c95cd3fcafa19959644f Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 2 Apr 2020 16:29:14 +0100 Subject: [PATCH] [Frontend][Torch] Fix up graph input handling (#5204) * [Frontend][Torch] Simplify operator input handling * [Frontend][Torch] Allow user supplied input names to override graph inputs * Fix pylint issues * Updates from code review feedback * Fix tutorial to use shape list input * Disable intermittent test failure in topi vision test --- python/tvm/relay/frontend/pytorch.py | 155 +++++++++++++------------- python/tvm/relay/frontend/qnn_torch.py | 7 +- tests/python/frontend/pytorch/qnn_test.py | 7 +- tests/python/frontend/pytorch/test_forward.py | 14 +-- topi/tests/python/test_topi_vision.py | 3 + tutorials/frontend/from_pytorch.py | 7 +- 6 files changed, 99 insertions(+), 94 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 269eb4c..9a08af9 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1071,16 +1071,8 @@ def _get_input_names(node_or_graph): return [inp.debugName() for inp in node_or_graph.inputs()] -def _get_op_inputs(op_node, outputs, output_index_map): - input_names = [output_index_map[name] - for name in _get_input_names(op_node)] - return [outputs[name] for name in input_names] - - -def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map): - for output_name, output in name_output_pairs: - output_index_map[output_name] = len(outputs) - outputs.append(output) +def _get_op_inputs(op_node, outputs): + return [outputs[name] for name in _get_input_names(op_node)] def _report_missing_conversion(op_names): @@ -1100,18 +1092,31 @@ def _report_missing_conversion(op_names): raise NotImplementedError(msg) -def _check_input_names(script_module, input_shapes): - """ Check the graph inputs match the inputs """ - ir_inputs = get_graph_input_names(script_module) - - for ir_input in ir_inputs: - if ir_input not in input_shapes: - msg = "Missing graph input {} in input_shapes".format(ir_input) - raise RuntimeError(msg) - - for input_name in input_shapes: - if input_name not in ir_inputs: - msg = "Unused graph input {} in input_shapes".format(input_name) +def _check_inputs(graph, input_shapes): + """ + Check the graph inputs match the expected number of inputs + and are in the correct format + """ + ir_inputs = _get_graph_input_names(graph) + + if not isinstance(input_shapes, list): + msg = "Graph inputs input_shapes should be list" + raise RuntimeError(msg) + missing_inputs = len(ir_inputs) - len(input_shapes) + if missing_inputs > 0: + msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs) + raise RuntimeError(msg) + + for num, inp in enumerate(input_shapes): + if num < len(ir_inputs): + if not isinstance(inp, tuple): + msg = "Graph input {} is not a tuple".format(num) + raise RuntimeError(msg) + if (len(inp) != 2 or not isinstance(inp[0], str)): + msg = "Graph input {} is not valid, expected ('name', shape)".format(inp) + raise RuntimeError(msg) + else: + msg = "Unused graph input {} in input_shapes".format(inp) logging.warning(msg) @@ -1203,10 +1208,19 @@ def _get_operator_nodes(nodes): return ops -def _get_relay_input_vars(input_shapes): - """ Return Relay vars from input shapes """ - return {iname: _expr.var(iname, shape=ishape) - for iname, ishape in input_shapes.items()} +def _get_relay_input_vars(graph, input_shapes): + """ + Return Relay vars from input shapes and create entries based on + expected graph inputs - to allow translation + """ + input_vars = {} + ir_inputs = _get_graph_input_names(graph) + for ir_input, (name, shape) in zip(ir_inputs, input_shapes): + inp = _expr.var(name, shape=shape) + # Translate from graph input to user input name + input_vars[ir_input] = inp + + return input_vars def get_use_chains(root_node, terminate=lambda _: False): @@ -1284,24 +1298,24 @@ def convert_params(graph, state_dict): return params, param_tensors, packed_param_map -def convert_block(block, outputs, output_index_map): +def convert_block(block, outputs): """ Translate Torch "Block", used for prim::If and prim::Loop """ ops = _get_operator_nodes(block.nodes()) ret_names = _get_input_names(block.returnNode()) - return convert_operators(ops, outputs, output_index_map, ret_names) + return convert_operators(ops, outputs, ret_names) -def convert_if(if_node, outputs, output_index_map): +def convert_if(if_node, outputs): """ Translate Torch prim::If to Relay If """ - cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]] + cond = outputs[if_node.inputsAt(0).debugName()] blocks = list(if_node.blocks()) - true_branch = convert_block(blocks[0], outputs, output_index_map) - false_branch = convert_block(blocks[1], outputs, output_index_map) + true_branch = convert_block(blocks[0], outputs) + false_branch = convert_block(blocks[1], outputs) assert len(true_branch) == 1 and len(false_branch) == 1 return _expr.If(cond, true_branch[0], false_branch[0]) -def convert_loop(loop_node, outputs, output_index_map): +def convert_loop(loop_node, outputs): """ Translate Torch prim::Loop to Relay while_loop """ def get_input(index): ivalue = loop_node.inputsAt(index) @@ -1309,8 +1323,8 @@ def convert_loop(loop_node, outputs, output_index_map): if inode.kind() == "prim::Constant": return _expr.const(_get_constant(inode)) var_name = ivalue.debugName() - assert var_name in output_index_map - return _wrap_const(outputs[output_index_map[var_name]]) + assert var_name in outputs + return _wrap_const(outputs[var_name]) # Refer to the spec for prim::Loop below # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops @@ -1342,9 +1356,9 @@ def convert_loop(loop_node, outputs, output_index_map): # Update loop variables using the prev iteration outputs assert len(current_vals) == len(block_input_names) for (i, iname) in enumerate(block_input_names): - outputs[output_index_map[iname]] = current_vals[i] + outputs[iname] = current_vals[i] - block_outputs = convert_block(body_block, outputs, output_index_map) + block_outputs = convert_block(body_block, outputs) if not is_while_loop: # iter var increment implicit in torch, so do it manually @@ -1374,7 +1388,7 @@ def convert_loop(loop_node, outputs, output_index_map): name_val_pairs = list(zip(block_input_names, [init_loop_iter_val] + init_vals)) - _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map) + outputs.update(name_val_pairs) loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype) @@ -1386,36 +1400,30 @@ def convert_loop(loop_node, outputs, output_index_map): return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)] -def convert_operators(operators, outputs, output_index_map, ret_names): +def convert_operators(operators, outputs, ret_names): """ Convert each Torch IR operators to Relay equivalent """ for node_name, op_node in operators: operator = op_node.kind() - inputs = _get_op_inputs(op_node, outputs, output_index_map) + inputs = _get_op_inputs(op_node, outputs) if operator == "prim::Constant": - output_index_map[node_name] = len(outputs) - outputs.append(_get_constant(op_node)) + outputs[node_name] = _get_constant(op_node) elif operator == 'prim::ListConstruct' and _is_int_seq(inputs): - output_index_map[node_name] = len(outputs) - outputs.append(_expr.var(node_name, shape=inputs)) + outputs[node_name] = _expr.var(node_name, shape=inputs) elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']: - output_index_map[node_name] = len(outputs) - outputs.append(inputs) + outputs[node_name] = inputs elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']: assert len(inputs) == 1 unpacked_names = _get_output_names(op_node) - _update_outputs_from_pairs(zip(unpacked_names, inputs[0]), - outputs, output_index_map) + outputs.update(zip(unpacked_names, inputs[0])) elif operator == "prim::If": - if_out = convert_if(op_node, outputs, output_index_map) - output_index_map[node_name] = len(outputs) - outputs.append(if_out) + if_out = convert_if(op_node, outputs) + outputs[node_name] = if_out elif operator == "prim::Loop": - loop_out = convert_loop(op_node, outputs, output_index_map) + loop_out = convert_loop(op_node, outputs) unpacked_names = _get_output_names(op_node) assert len(loop_out) == len(unpacked_names) - _update_outputs_from_pairs(zip(unpacked_names, loop_out), - outputs, output_index_map) + outputs.update(zip(unpacked_names, loop_out)) else: relay_op = _convert_map[operator] relay_out = relay_op(inputs, _get_input_types(op_node)) @@ -1424,13 +1432,11 @@ def convert_operators(operators, outputs, output_index_map, ret_names): # This is for torch operators that return multiple outputs # See _adaptive_max_2d above for example out_names = _get_output_names(op_node) - _update_outputs_from_pairs(zip(out_names, relay_out), - outputs, output_index_map) + outputs.update(zip(out_names, relay_out)) else: - output_index_map[node_name] = len(outputs) - outputs.append(relay_out) + outputs[node_name] = relay_out - return [_wrap_const(outputs[output_index_map[ret_name]]) + return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] @@ -1446,11 +1452,11 @@ def get_all_op_names(graph): return set(node.kind() for node in nodes) -def get_graph_input_names(script_module): - """ Use this function to set the keys for input_shapes""" - # It seems variable names could change the first time a copy is made - # Use the copy of the graph here to prevent troubles later - ir_inputs = _get_input_names(script_module.graph.copy()) +def _get_graph_input_names(graph): + """ Get the graph input names (use after graph copy and run jit passes) """ + # Variable names could change the first time a copy is made and after + # _run_jit_passes is called, expected that those functions already invoked + ir_inputs = _get_input_names(graph) return ir_inputs[1:] # remove self at the 0th arg @@ -1464,9 +1470,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): TorchScripted PyTorch graph Note: We currently only support traces (ie: torch.jit.trace(model, input)) - input_shapes : Dictionary of input dimensions - Graph level input shape dictionary - The keys should be the same one returned by get_graph_input_names(...) above + input_shapes : List of tuples of input name and input dimensions + Graph level input shape list + The same input names need to be used for deployment, so choose easy to + remember names (such as: input0, input1) custom_convert_map: Dictionary of str to Relay op A custom op conversion map in the same format as _convert_map above @@ -1487,30 +1494,28 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None): op_names = get_all_op_names(graph) _report_missing_conversion(op_names) - _check_input_names(script_module, input_shapes) + _check_inputs(graph, input_shapes) params = script_module.state_dict() - input_vars = _get_relay_input_vars(input_shapes) + outputs = _get_relay_input_vars(graph, input_shapes) param_vars, tensors, packed_param_map = convert_params(graph, params) tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()} - input_vars.update(param_vars) - outputs = list(input_vars.values()) - output_index_map = dict(zip(input_vars.keys(), range(len(outputs)))) + outputs.update(param_vars) ret_name = _get_input_names(graph.return_node()) # For quantized models if "aten::quantize_per_tensor" in op_names: weight_quant_params = qnn_torch.get_weight_quant_params(script_module) qnn_torch.add_input_quant_params_to_op_inputs(graph) - qnn_torch.add_quant_params_to_outputs(outputs, output_index_map, + qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params) qnn_torch.add_quant_params(tvm_params, weight_quant_params) _convert_map.update(qnn_torch.convert_map) - ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs, - output_index_map, ret_name) + ret = convert_operators(_get_operator_nodes(graph.nodes()), + outputs, ret_name) if isinstance(ret[0], list): ret[0] = _expr.Tuple(ret[0]) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index e6a015f..fb90649 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -101,20 +101,19 @@ def get_weight_quant_params(script_module): return quant_params -def add_quant_params_to_outputs(outputs, output_index_map, - packed_param_map, quant_params): +def add_quant_params_to_outputs(outputs, packed_param_map, + quant_params): """ Add quant params to outputs so that they can be referenced by other ops later. Weights are quantized here. """ for node_name, packed_param_name in packed_param_map.items(): qparam = quant_params[packed_param_name] - output_index_map[node_name] = len(outputs) qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0) param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var) - outputs.append(param_tup) + outputs[node_name] = param_tup def _get_quant_param_for_input(input_value): diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 6cd7c1f..82e3393 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -28,7 +28,6 @@ from torch.quantization import fuse_modules, QuantWrapper import tvm from tvm import relay -from tvm.relay.frontend.pytorch import get_graph_input_names from tvm.contrib.download import download_testdata @@ -39,7 +38,7 @@ def torch_version_check(): def get_tvm_runtime(script_module, input_name, ishape): - input_shapes = {input_name: ishape} + input_shapes = [(input_name, ishape)] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) with relay.build_config(opt_level=3): @@ -287,7 +286,7 @@ def test_quantized_modules(): with torch.no_grad(): pt_result = script_module(inp.clone()).numpy() - input_name = get_graph_input_names(script_module)[0] + input_name = "input" runtime = get_tvm_runtime(script_module, input_name, ishape) runtime.set_input(input_name, inp.numpy().copy()) runtime.run() @@ -383,7 +382,7 @@ def test_quantized_imagenet(): with torch.no_grad(): pt_result = script_module(pt_inp).numpy() - input_name = get_graph_input_names(script_module)[0] + input_name = "image" runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224)) runtime.set_input(input_name, inp) runtime.run() diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 1f083cb..c75ae6e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -28,7 +28,6 @@ import torchvision from tvm import relay from tvm.contrib import graph_runtime from tvm.relay.testing.config import ctx_list -from tvm.relay.frontend.pytorch import get_graph_input_names sys.setrecursionlimit(10000) @@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[], else: trace = trace.cpu() - input_names = get_graph_input_names(trace) - input_shapes = dict(zip(input_names, + input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] + input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) @@ -888,11 +887,12 @@ def test_3d_models(): def verify_script_model(pt_model, ishapes): script_module = torch.jit.script(pt_model) - input_names = get_graph_input_names(script_module) - input_shapes = dict(zip(input_names, ishapes)) - inputs = [torch.randn(input_shapes[input_name], dtype=torch.float) - for input_name in input_names] + input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] + input_shapes = list(zip(input_names, ishapes)) + + inputs = [torch.randn(shape, dtype=torch.float) + for shape in ishapes] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 0aa410d..fe94a4c 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -103,11 +103,14 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) + """ Skip this test as it is intermittent + see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094 for device in ['llvm', 'cuda', 'opencl']: # Disable opencl test for now if device != "llvm" and device != "cuda": continue check_device(device) + """ def test_get_valid_counts(): diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 1c568ce..45e3cb8 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -47,7 +47,6 @@ from tvm import relay import numpy as np from tvm.contrib.download import download_testdata -from tvm.relay.frontend.pytorch import get_graph_input_names # PyTorch imports import torch @@ -90,10 +89,10 @@ img = np.expand_dims(img, 0) # Import the graph to Relay # ------------------------- # Convert PyTorch graph to Relay graph. -input_name = get_graph_input_names(scripted_model)[0] # only one input -shape_dict = {input_name: img.shape} +input_name = 'input0' # only one input, set it to this name +shape_list = [(input_name, img.shape)] mod, params = relay.frontend.from_pytorch(scripted_model, - shape_dict) + shape_list) ###################################################################### # Relay Build -- 2.7.4