From 06e9542ee0bfd014bd06a4dd4fdb3af9d2d29eb0 Mon Sep 17 00:00:00 2001
From: masahi
Date: Wed, 11 Mar 2020 02:59:09 +0900
Subject: [PATCH] [Torch] Add initial control flow support (#4964)

* Add support for prim::If and prim::Loop with test cases

* rebase and fix tests

* add some comments

* simplifying, fix float cast

* parse -> convert

* recursively retrieve ops in get_all_op_names

* use multiple return values from block correctly, simplify loop convert

* choose dtype properly for zeros and ones

* simplifying, replace convert_inputs with _get_relay_input_vars

* fix for while loop with non-input-dependent init cond

* add assert on loop var update

* move the condition around

* better testing for seg models

* rebase fix, disable inception v3 in quant test as it is too slow to load
  with torch-1.4 + torchvision 0.5

* simplify and add more comparison op converters
---
 python/tvm/relay/frontend/pytorch.py          | 223 ++++++++++++++++++++++----
 tests/python/frontend/pytorch/qnn_test.py     |   3 +-
 tests/python/frontend/pytorch/test_forward.py | 197 ++++++++++++++++++++++-
 3 files changed, 385 insertions(+), 38 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ff37f82..6da91c1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -20,6 +20,7 @@
 """PT: PyTorch frontend."""
 import itertools
 import logging
+import sys
 
 import numpy as np
 
@@ -29,6 +30,7 @@ from tvm.ir import module as _module
 from .. import analysis as _analysis
 from .. import expr as _expr
 from .. import op as _op
+from ..loops import while_loop
 from .common import get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
@@ -107,9 +109,8 @@ def _select():
     def _impl(inputs, input_types):
         data = inputs[0]
         dim = int(inputs[1])
-        index = int(inputs[2])
-
-        return _op.transform.take(data, _expr.const(index, dtype="int32"), axis=dim)
+        index = _wrap_const(inputs[2])
+        return _op.transform.take(data, index, axis=dim)
     return _impl
 
 def _ones():
@@ -126,7 +127,10 @@ def _ones():
         else:
             assert "data type {} could not be parsed in ones op" % (type(data))
 
-        return _op.full(_expr.const(1), shape, dtype=_convert_data_type(input_types[0]))
+        dtype_map = {6: "float32", 3: "int32"}
+        dtype_id = inputs[1]
+        assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+        return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id])
     return _impl
 
 def _zeros():
@@ -143,7 +147,10 @@ def _zeros():
         else:
             assert "data type {} could not be parsed in zeros op" % (type(data))
 
-        return _op.full(_expr.const(0), shape, dtype=_convert_data_type(input_types[0]))
+        dtype_map = {6: "float32", 3: "int32"}
+        dtype_id = inputs[1]
+        assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+        return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id])
     return _impl
 
 def _relu():
@@ -222,12 +229,10 @@ def _convolution():
         else:
             assert "data type {} could not be parsed in conv op" % (type(weight))
 
-        # TODO: Add reshape when channel multiplier > 1. Pending PR #4644
         channels = weight_shape[0]
         groups = int(inputs[8])
 
         if groups > 1:
-            # in torch, groups == in_channels for depth wise conv
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
            weight = _op.transform.reshape(weight, new_weight_shape)
@@ -496,7 +501,7 @@ def _dropout():
     return _impl
 
 def _reduce(name):
-    def _impl(inputs, attrs, params):
+    def _impl(inputs, input_types):
         data = inputs[0]
         return get_relay_op(name)(data)
     return _impl
@@ -714,7 +719,6 @@ def _upsample(method):
 
     return _impl
 
-
 def _expand_as():
     def _impl(inputs, input_types):
         # TODO: maybe fix this
@@ -724,6 +728,29 @@ def _expand_as():
         return inputs[0]
     return _impl
 
+def _neg():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.negative(data)
+    return _impl
+
+def _tanh():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        return _op.tensor.tanh(data)
+    return _impl
+
+def _Bool():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 1
+        return inputs[0]
+    return _impl
+
+def _Float():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 1
+        return _op.cast(inputs[0], "float32")
+    return _impl
 
 # Helper functions for operator implementation
@@ -780,6 +807,11 @@ def _convert_elemwise_input(data, input_type):
     else:
         return data
 
+def _wrap_const(c):
+    if not isinstance(c, _expr.Expr) and not isinstance(c, list):
+        return _expr.const(c)
+    return c
+
 
 # Operator mappings
 _convert_map = {
@@ -845,7 +877,16 @@ _convert_map = {
     "aten::detach"                          : _identity(),
     "aten::upsample_bilinear2d"             : _upsample("bilinear"),
     "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
-    "aten::expand_as"                       : _expand_as()
+    "aten::expand_as"                       : _expand_as(),
+    "aten::lt"                              : _elemwise("less"),
+    "aten::gt"                              : _elemwise("greater"),
+    "aten::le"                              : _elemwise("less_equal"),
+    "aten::ge"                              : _elemwise("greater_equal"),
+    "aten::ne"                              : _elemwise("not_equal"),
+    "aten::Bool"                            : _Bool(),
+    "aten::Float"                           : _Float(),
+    "aten::neg"                             : _neg(),
+    "aten::tanh"                            : _tanh(),
 }
 
@@ -894,7 +935,8 @@ def _report_missing_conversion(op_names):
     """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
-                 "prim::TupleConstruct", "prim::TupleUnpack"]
+                 "prim::TupleConstruct", "prim::TupleUnpack",
+                 "prim::If", "prim::Loop"]
     known_ops += list(_convert_map.keys())
     known_ops += list(qnn_torch.convert_map.keys())
 
@@ -939,9 +981,13 @@ def _get_input_types(op_node):
         input_node_kind = in_ty.kind()
         if input_node_kind == 'TensorType':
             if in_ty.scalarType() is None:
-                input_list_types.append(None)
+                # Tensor's type can be unknown if we use torch.jit.script(...)
+                # Defaults to float for now
+                logging.warning("Untyped Tensor found, assume it is float")
+                input_list_types.append("float")
             else:
                 input_list_types.append(in_ty.scalarType().lower())
+
         elif input_node_kind == 'ListType':
             input_list_types.append(str(in_ty.getElementType()).lower())
         elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
@@ -1004,15 +1050,10 @@
     return ops
 
 
-def parse_inputs(graph_inputs, input_shapes):
-    """ Return Relay vars from torch input vars """
-    ir_inputs = list(graph_inputs)
-    input_vars = {}
-
-    for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
-        input_vars[input_name] = _expr.var(input_name,
-                                           shape=input_shapes[input_name])
-    return input_vars
+def _get_relay_input_vars(input_shapes):
+    """ Return Relay vars from input shapes """
+    return {iname: _expr.var(iname, shape=ishape)
+            for iname, ishape in input_shapes.items()}
 
 
 def get_use_chains(root_node, terminate=lambda _: False):
@@ -1055,7 +1096,7 @@ def get_attr_chains(root_getattr_node):
     return get_use_chains(root_getattr_node, terminate)
 
 
-def parse_params(graph, state_dict):
+def convert_params(graph, state_dict):
     """ Return Relay vars and TVM NDArrays for input parameters
     A chain of prim::GetAttr nodes is processed one at a time
     """
@@ -1090,7 +1131,109 @@
     return params, param_tensors, packed_param_map
 
 
-def parse_operators(operators, outputs, output_index_map, ret_name):
+def convert_block(block, outputs, output_index_map):
+    """ Translate Torch "Block", used for prim::If and prim::Loop """
+    ops = _get_operator_nodes(block.nodes())
+    ret_names = _get_input_names(block.returnNode())
+    return convert_operators(ops, outputs, output_index_map, ret_names)
+
+
+def convert_if(if_node, outputs, output_index_map):
+    """ Translate Torch prim::If to Relay If """
+    cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+    blocks = list(if_node.blocks())
+    true_branch = convert_block(blocks[0], outputs, output_index_map)
+    false_branch = convert_block(blocks[1], outputs, output_index_map)
+    assert len(true_branch) == 1 and len(false_branch) == 1
+    return _expr.If(cond, true_branch[0], false_branch[0])
+
+
+def convert_loop(loop_node, outputs, output_index_map):
+    """ Translate Torch prim::Loop to Relay while_loop """
+    def get_input(index):
+        ivalue = loop_node.inputsAt(index)
+        inode = ivalue.node()
+        if inode.kind() == "prim::Constant":
+            return _expr.const(_get_constant(inode))
+        var_name = ivalue.debugName()
+        assert var_name in output_index_map
+        return _wrap_const(outputs[output_index_map[var_name]])
+
+    # Refer to the spec for prim::Loop below
+    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+    # The first input: %max_trip_count
+    # The second input: %initial_condition
+    # The rest of the inputs: loop variables
+    max_loop_count = get_input(0)
+    init_cond = get_input(1)
+    num_loop_var = len(list(loop_node.inputs())) - 2
+    init_vals = [get_input(i + 2) for i in range(num_loop_var)]
+
+    # a while loop always has max_loop_count set to int64 max
+    # max_loop_count.data (tvm.runtime.NDArray) is -1, so call _get_constant again
+    is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
+                     _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)
+
+    body_block = list(loop_node.blocks())[0]
+    block_input_names = _get_input_names(body_block)
+
+    def cond(*current_vals):
+        i = current_vals[0]
+
+        if is_while_loop:
+            return _op.equal(i, _expr.const(True, 'bool'))
+
+        return _op.less(i, max_loop_count)
+
+    def body(*current_vals):
+        # Update loop variables using the previous iteration outputs
+        assert len(current_vals) == len(block_input_names)
+        for (i, iname) in enumerate(block_input_names):
+            outputs[output_index_map[iname]] = current_vals[i]
+
+        block_outputs = convert_block(body_block, outputs, output_index_map)
+
+        if not is_while_loop:
+            # the iter var increment is implicit in torch, so do it manually
+            # for while loop, block_outputs[0] is already a boolean,
+            # the result of termination check
+            incr = _expr.const(1, dtype="int32")
+            block_outputs[0] = current_vals[0] + incr
+
+        return block_outputs
+
+    def get_var(name, val):
+        if isinstance(val, _expr.Constant):
+            return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
+        return _expr.var(name)
+
+    if is_while_loop:
+        loop_iter_dtype = "bool"
+        # for a while loop with a non-input-dependent condition such as
+        # while i < 10, init_cond is an int, so cast it to bool to type check
+        if isinstance(init_cond, _expr.Constant):
+            init_cond = _op.cast(init_cond, "bool")
+        init_loop_iter_val = init_cond
+    else:
+        loop_iter_dtype = "int32"
+        # always count from 0
+        init_loop_iter_val = _expr.const(0, dtype="int32")
+
+    name_val_pairs = list(zip(block_input_names,
+                              [init_loop_iter_val] + init_vals))
+    _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
+
+    loop_iter_var = _expr.var(block_input_names[0], shape=(),
+                              dtype=loop_iter_dtype)
+    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+    loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
+    loop_val = loop(init_loop_iter_val, *init_vals)
+
+    # The first element is a loop counter or boolean condition, ignore it
+    return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
+
+
+def convert_operators(operators, outputs, output_index_map, ret_names):
     """ Convert each Torch IR operators to Relay equivalent """
     for node_name, op_node in operators:
         operator = op_node.kind()
@@ -1110,17 +1253,35 @@
             unpacked_names = _get_output_names(op_node)
             _update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
                                        outputs, output_index_map)
+        elif operator == "prim::If":
+            if_out = convert_if(op_node, outputs, output_index_map)
+            output_index_map[node_name] = len(outputs)
+            outputs.append(if_out)
+        elif operator == "prim::Loop":
+            loop_out = convert_loop(op_node, outputs, output_index_map)
+            unpacked_names = _get_output_names(op_node)
+            assert len(loop_out) == len(unpacked_names)
+            _update_outputs_from_pairs(zip(unpacked_names, loop_out),
+                                       outputs, output_index_map)
         else:
             output_index_map[node_name] = len(outputs)
             relay_op = _convert_map[operator]
             outputs.append(relay_op(inputs, _get_input_types(op_node)))
 
-    return outputs[output_index_map[ret_name]]
+    return [_wrap_const(outputs[output_index_map[ret_name]])
+            for ret_name in ret_names]
 
 
 def get_all_op_names(graph):
     """ Return all operator names in the input graph """
-    return set(node.kind() for node in graph.nodes())
+    nodes = list(graph.nodes())
+    prim_with_blocks = ["prim::If", "prim::Loop"]
+    for prim in prim_with_blocks:
+        prim_nodes = graph.findAllNodes(prim, recurse=True)
+        for prim_node in prim_nodes:
+            for block in prim_node.blocks():
+                nodes += block.nodes()
+    return set(node.kind() for node in nodes)
 
 
 def get_graph_input_names(script_module):
@@ -1167,14 +1328,14 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     _check_input_names(script_module, input_shapes)
 
     params = script_module.state_dict()
-    input_vars = parse_inputs(graph.inputs(), input_shapes)
-    param_vars, tensors, packed_param_map = parse_params(graph, params)
+    input_vars = _get_relay_input_vars(input_shapes)
+    param_vars, tensors, packed_param_map = convert_params(graph, params)
 
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
-    ret_name = _get_input_names(graph.return_node())[0]
+    ret_name = _get_input_names(graph.return_node())
 
     # For quantized models
     if "aten::quantize_per_tensor" in op_names:
@@ -1186,8 +1347,8 @@
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         _convert_map.update(qnn_torch.convert_map)
 
-    body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
-                           output_index_map, ret_name)
-    func = tvm.relay.Function(_analysis.free_vars(body), body)
+    ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
+                            output_index_map, ret_name)
+    func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
 
     return _module.IRModule.from_expr(func), tvm_params
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index e3a876c..23fcb7c 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -347,7 +347,8 @@ def test_quantized_imagenet():
     qmodels += [
        ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
        ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
-       ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
+       # disable inception test for now, since loading it takes ~5min on torchvision-0.5
+       #("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
        ("googlenet", qgooglenet(pretrained=True), per_channel),
     ]
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index eed47ea..59f93b4 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -756,7 +756,6 @@ def test_vgg11_bn():
     verify_model("vgg11_bn")
 
 """
-
 def test_custom_conversion_map():
     def get_roi_align():
         pool_size = 5
@@ -801,11 +800,193 @@ def test_segmentaton_models():
 
     inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
 
-    for model in [fcn, deeplab]:
-        # depthwise + dilated covolution not supported on x86
-        # see https://github.com/apache/incubator-tvm/issues/4962
-        verify_model(SegmentationModelWrapper(model.eval()), inp,
-                     ctx_list=[("cuda", tvm.gpu(0))])
+    verify_model(SegmentationModelWrapper(fcn.eval()), inp)
+
+    # depthwise + dilated convolution not supported on x86
+    # see https://github.com/apache/incubator-tvm/issues/4962
+    cuda_ctx = ("cuda", tvm.gpu(0))
+    if cuda_ctx[1].exist:
+        verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
+
+
+def verify_script_model(pt_model, ishapes):
+    script_module = torch.jit.script(pt_model)
+    input_names = get_graph_input_names(script_module)
+    input_shapes = dict(zip(input_names, ishapes))
+
+    inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
+              for input_name in input_names]
+
+    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
+                                     target="llvm")
+    evaluator = executor.evaluate()
+
+    for name, inp in zip(input_names, inputs):
+        params[name] = inp.numpy()
+
+    op_res = evaluator(**params)
+
+    with torch.no_grad():
+        pt_result = pt_model(*inputs)
+
+    if not isinstance(pt_result, torch.Tensor):
+        tvm_res = op_res.asnumpy().item()
+        assert pt_result == tvm_res
+    else:
+        tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
+                                    rtol=1e-5, atol=1e-5)
+
+
+def test_control_flow():
+    class SimpleIf(torch.nn.Module):
+        def __init__(self, N, M):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+        def forward(self, inp):
+            if inp.sum() > 0.:
+                output = self.weight + inp
+            else:
+                output = self.weight - inp
+            return output
+
+    class NestedIf(torch.nn.Module):
+        def __init__(self, N, M):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+        def forward(self, inp):
+            if inp.sum() > 0.:
+                if inp.mean() > 0.:
+                    output = self.weight + inp
+                else:
+                    output = self.weight - inp
+            else:
+                if inp.mean() >= 0.:
+                    output = self.weight * inp
+                else:
+                    output = self.weight / inp
+
+            return output
+
+    class ScalarLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = 0
+            for i in range(inp.size(0)):
+                b = i * i
+                b = b + 1
+                a += b
+            if a != 0:
+                a += 1
+            else:
+                a += 2
+            return a
+
+    class SimpleLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * 2.
+                c = a + b
+                a += c
+            return a
+
+    class LoopWithIf(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * 2.
+                b = a + b
+                if b.sum() > 0.0:
+                    a += b
+                else:
+                    a -= b
+            return a
+
+    class NestedLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            for i in range(inp.size(0)):
+                b = a * float(i)
+                for j in range(inp.size(1)):
+                    a += b * float(j)
+            return a
+
+    class SimpleScalarWhileLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = 1
+            i = 0
+            while i <= inp.size(0):
+                a += i
+                i += 2
+            i = 0
+            # also test constant init cond
+            while i < 10:
+                a += i
+                i += 3
+            return a
+
+    class SimpleWhileLoop(torch.nn.Module):
+        def forward(self, inp):
+            a = inp
+            i = 0
+            while i < inp.size(0):
+                a += a * float(i) * 2.0
+                i += 1
+            return a
+
+    models = [
+        SimpleIf(10, 20),
+        NestedIf(10, 20),
+        ScalarLoop(),
+        SimpleLoop(),
+        LoopWithIf(),
+        SimpleScalarWhileLoop(),
+        SimpleWhileLoop(),
+        NestedLoop(),
+    ]
+
+    for pt_model in models:
+        verify_script_model(pt_model.eval(), [(10, 20)])
+
+
+def test_simple_rnn():
+    # The mixed tracing and scripting example from
+    # https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#mixing-scripting-and-tracing
+    class DecisionGate(torch.nn.Module):
+        def forward(self, x):
+            if x.sum() > 0:
+                return x
+            else:
+                return -x
+
+    class Cell(torch.nn.Module):
+        def __init__(self, dg):
+            super(Cell, self).__init__()
+            self.dg = dg
+            self.linear = torch.nn.Linear(4, 4)
+
+        def forward(self, x, h):
+            new_h = torch.tanh(self.dg(self.linear(x)) + h)
+            return new_h, new_h
+
+    class RNNLoop(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            x = torch.rand(10, 4, dtype=torch.float)
+            h = torch.rand(10, 4, dtype=torch.float)
+            self.cell = torch.jit.trace(Cell(DecisionGate()), (x, h))
+
+        def forward(self, xs):
+            h = torch.zeros(10, 4, dtype=torch.float)
+            y = torch.zeros(10, 4, dtype=torch.float)
+            for i in range(xs.size(0)):
+                y, h = self.cell(xs[i], h)
+            return y
+
+    verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
 
 
 if __name__ == "__main__":
@@ -860,3 +1041,7 @@ if __name__ == "__main__":
 
     test_quantized_modules()
     test_quantized_imagenet()
+
+    # Test simple conditionals and loops
+    test_control_flow()
+    test_simple_rnn()
-- 
2.7.4
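
Usage sketch (not part of the diff above): a minimal end-to-end conversion of a
scripted model whose Python loop becomes a prim::Loop, mirroring the
verify_script_model helper added in the tests. The Accumulate module and the
(10, 20) input shape are illustrative only; the API calls are the ones this
patch's own tests use (torch.jit.script, get_graph_input_names,
relay.frontend.from_pytorch, the "vm" executor), assuming the torch-1.4 /
torchvision-0.5 era TVM that this patch targets.

    import torch
    import tvm
    from tvm import relay
    from tvm.relay.frontend.pytorch import get_graph_input_names

    # Hypothetical model for illustration: the Python for-loop is compiled
    # to a prim::Loop node by torch.jit.script, which the new convert_loop
    # translates into a Relay while_loop.
    class Accumulate(torch.nn.Module):
        def forward(self, inp):
            a = inp
            for i in range(inp.size(0)):
                a = a + a * 2.0
            return a

    script_module = torch.jit.script(Accumulate().eval())
    input_name = get_graph_input_names(script_module)[0]
    mod, params = relay.frontend.from_pytorch(script_module,
                                              {input_name: (10, 20)})

    # Control flow requires the VM executor, as in verify_script_model above
    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
                                     target="llvm")
    evaluator = executor.evaluate()
    params[input_name] = torch.randn(10, 20).numpy()
    res = evaluator(**params).asnumpy()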