    return [inp.debugName() for inp in node_or_graph.inputs()]
-def _get_op_inputs(op_node, outputs, output_index_map):
-    input_names = [output_index_map[name]
-                   for name in _get_input_names(op_node)]
-    return [outputs[name] for name in input_names]
-
-
-def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
-    for output_name, output in name_output_pairs:
-        output_index_map[output_name] = len(outputs)
-        outputs.append(output)
+def _get_op_inputs(op_node, outputs):
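+    # "outputs" is keyed by value debug names, so op inputs are direct lookups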
+    return [outputs[name] for name in _get_input_names(op_node)]
def _report_missing_conversion(op_names):
        raise NotImplementedError(msg)
-def _check_input_names(script_module, input_shapes):
-    """ Check the graph inputs match the inputs """
-    ir_inputs = get_graph_input_names(script_module)
-
-    for ir_input in ir_inputs:
-        if ir_input not in input_shapes:
-            msg = "Missing graph input {} in input_shapes".format(ir_input)
-            raise RuntimeError(msg)
-
-    for input_name in input_shapes:
-        if input_name not in ir_inputs:
-            msg = "Unused graph input {} in input_shapes".format(input_name)
+def _check_inputs(graph, input_shapes):
+    """
+    Check that the graph inputs match the expected number of inputs
+    and that each entry is in the correct format
+    """
+    ir_inputs = _get_graph_input_names(graph)
+
+    if not isinstance(input_shapes, list):
+        msg = "Graph inputs input_shapes should be a list"
+        raise RuntimeError(msg)
+    missing_inputs = len(ir_inputs) - len(input_shapes)
+    if missing_inputs > 0:
+        msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
+        raise RuntimeError(msg)
+
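+    # Validate each ('name', shape) pair and warn on extra, unused entries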
+    for num, inp in enumerate(input_shapes):
+        if num < len(ir_inputs):
+            if not isinstance(inp, tuple):
+                msg = "Graph input {} is not a tuple".format(num)
+                raise RuntimeError(msg)
+            if len(inp) != 2 or not isinstance(inp[0], str):
+                msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
+                raise RuntimeError(msg)
+        else:
+            msg = "Unused graph input {} in input_shapes".format(inp)
            logging.warning(msg)
    return ops
-def _get_relay_input_vars(input_shapes):
-    """ Return Relay vars from input shapes """
-    return {iname: _expr.var(iname, shape=ishape)
-            for iname, ishape in input_shapes.items()}
+def _get_relay_input_vars(graph, input_shapes):
+    """
+    Return Relay vars from input shapes, keyed by the expected graph
+    input names so that graph values can be mapped to the user-given names
+    """
+    input_vars = {}
+    ir_inputs = _get_graph_input_names(graph)
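+    # Pair expected graph inputs with user-supplied (name, shape) tuples in order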
+    for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
+        inp = _expr.var(name, shape=shape)
+        # Translate from graph input to user input name
+        input_vars[ir_input] = inp
+
+    return input_vars
def get_use_chains(root_node, terminate=lambda _: False):
    return params, param_tensors, packed_param_map
-def convert_block(block, outputs, output_index_map):
+def convert_block(block, outputs):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode())
- return convert_operators(ops, outputs, output_index_map, ret_names)
+ return convert_operators(ops, outputs, ret_names)
-def convert_if(if_node, outputs, output_index_map):
+def convert_if(if_node, outputs):
""" Translate Torch prim::If to Relay If """
- cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+ cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
- true_branch = convert_block(blocks[0], outputs, output_index_map)
- false_branch = convert_block(blocks[1], outputs, output_index_map)
+ true_branch = convert_block(blocks[0], outputs)
+ false_branch = convert_block(blocks[1], outputs)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])
-def convert_loop(loop_node, outputs, output_index_map):
+def convert_loop(loop_node, outputs):
""" Translate Torch prim::Loop to Relay while_loop """
def get_input(index):
ivalue = loop_node.inputsAt(index)
if inode.kind() == "prim::Constant":
return _expr.const(_get_constant(inode))
var_name = ivalue.debugName()
- assert var_name in output_index_map
- return _wrap_const(outputs[output_index_map[var_name]])
+ assert var_name in outputs
+ return _wrap_const(outputs[var_name])
# Refer to the spec for prim::Loop below
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
# Update loop variables using the prev iteration outputs
assert len(current_vals) == len(block_input_names)
for (i, iname) in enumerate(block_input_names):
- outputs[output_index_map[iname]] = current_vals[i]
+ outputs[iname] = current_vals[i]
- block_outputs = convert_block(body_block, outputs, output_index_map)
+ block_outputs = convert_block(body_block, outputs)
if not is_while_loop:
# iter var increment implicit in torch, so do it manually
name_val_pairs = list(zip(block_input_names,
[init_loop_iter_val] + init_vals))
- _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
+ outputs.update(name_val_pairs)
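+    # The loop counter becomes an explicit scalar Relay var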
    loop_iter_var = _expr.var(block_input_names[0], shape=(),
                              dtype=loop_iter_dtype)
    return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
-def convert_operators(operators, outputs, output_index_map, ret_names):
+def convert_operators(operators, outputs, ret_names):
""" Convert each Torch IR operators to Relay equivalent """
    for node_name, op_node in operators:
        operator = op_node.kind()
-        inputs = _get_op_inputs(op_node, outputs, output_index_map)
+        inputs = _get_op_inputs(op_node, outputs)
        if operator == "prim::Constant":
-            output_index_map[node_name] = len(outputs)
-            outputs.append(_get_constant(op_node))
+            outputs[node_name] = _get_constant(op_node)
        elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
-            output_index_map[node_name] = len(outputs)
-            outputs.append(_expr.var(node_name, shape=inputs))
+            outputs[node_name] = _expr.var(node_name, shape=inputs)
        elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
-            output_index_map[node_name] = len(outputs)
-            outputs.append(inputs)
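+            # Keep the constructed sequence as a Python list of Relay exprs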
+            outputs[node_name] = inputs
        elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
            assert len(inputs) == 1
            unpacked_names = _get_output_names(op_node)
-            _update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
-                                       outputs, output_index_map)
+            outputs.update(zip(unpacked_names, inputs[0]))
        elif operator == "prim::If":
-            if_out = convert_if(op_node, outputs, output_index_map)
-            output_index_map[node_name] = len(outputs)
-            outputs.append(if_out)
+            if_out = convert_if(op_node, outputs)
+            outputs[node_name] = if_out
        elif operator == "prim::Loop":
-            loop_out = convert_loop(op_node, outputs, output_index_map)
+            loop_out = convert_loop(op_node, outputs)
            unpacked_names = _get_output_names(op_node)
            assert len(loop_out) == len(unpacked_names)
-            _update_outputs_from_pairs(zip(unpacked_names, loop_out),
-                                       outputs, output_index_map)
+            outputs.update(zip(unpacked_names, loop_out))
        else:
            relay_op = _convert_map[operator]
            relay_out = relay_op(inputs, _get_input_types(op_node))
                # This is for torch operators that return multiple outputs
                # See _adaptive_max_2d above for example
                out_names = _get_output_names(op_node)
-                _update_outputs_from_pairs(zip(out_names, relay_out),
-                                           outputs, output_index_map)
+                outputs.update(zip(out_names, relay_out))
            else:
-                output_index_map[node_name] = len(outputs)
-                outputs.append(relay_out)
+                outputs[node_name] = relay_out
-    return [_wrap_const(outputs[output_index_map[ret_name]])
+    return [_wrap_const(outputs[ret_name])
            for ret_name in ret_names]
    return set(node.kind() for node in nodes)
-def get_graph_input_names(script_module):
-    """ Use this function to set the keys for input_shapes"""
-    # It seems variable names could change the first time a copy is made
-    # Use the copy of the graph here to prevent troubles later
-    ir_inputs = _get_input_names(script_module.graph.copy())
+def _get_graph_input_names(graph):
+    """ Get the graph input names (use only after graph copy and jit passes) """
+    # Variable names can change the first time a copy is made and after
+    # _run_jit_passes is called; both are expected to have been run already
+    ir_inputs = _get_input_names(graph)
    return ir_inputs[1:]  # remove self at the 0th arg
        TorchScripted PyTorch graph
        Note: We currently only support traces (ie: torch.jit.trace(model, input))
-    input_shapes : Dictionary of input dimensions
-        Graph level input shape dictionary
-        The keys should be the same one returned by get_graph_input_names(...) above
+    input_shapes : List of tuples of input name and input dimensions
+        Graph level input shape list
+        The same input names need to be used for deployment, so choose
+        easy-to-remember names (such as input0, input1)
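+        For example: [("input0", (1, 3, 224, 224))] for a single image input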
    custom_convert_map: Dictionary of str to Relay op
        A custom op conversion map in the same format as _convert_map above
    op_names = get_all_op_names(graph)
    _report_missing_conversion(op_names)
-    _check_input_names(script_module, input_shapes)
+    _check_inputs(graph, input_shapes)
    params = script_module.state_dict()
-    input_vars = _get_relay_input_vars(input_shapes)
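+    # Seed "outputs" with the graph input vars; conversion extends it in place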
+    outputs = _get_relay_input_vars(graph, input_shapes)
    param_vars, tensors, packed_param_map = convert_params(graph, params)
    tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
-    input_vars.update(param_vars)
-    outputs = list(input_vars.values())
-    output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
+    outputs.update(param_vars)
    ret_name = _get_input_names(graph.return_node())
    # For quantized models
    if "aten::quantize_per_tensor" in op_names:
        weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
        qnn_torch.add_input_quant_params_to_op_inputs(graph)
-        qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
+        qnn_torch.add_quant_params_to_outputs(outputs,
                                              packed_param_map,
                                              weight_quant_params)
        qnn_torch.add_quant_params(tvm_params, weight_quant_params)
        _convert_map.update(qnn_torch.convert_map)
-    ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
-                            output_index_map, ret_name)
+    ret = convert_operators(_get_operator_nodes(graph.nodes()),
+                            outputs, ret_name)
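+    # Wrap a list return (multiple outputs) into a single Relay Tuple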
    if isinstance(ret[0], list):
        ret[0] = _expr.Tuple(ret[0])