From 7bc0b27ecb4de359243937a4b3954857a64f44fd Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Mon, 23 Mar 2020 10:28:19 -0700 Subject: [PATCH] [Frontend][TensorFlow]TensorFlow Parser Control Flow Enhancement (#5020) * Improve TF control flow major logic * Pass mod into operator convert function * Fix LoopBound * Add more control flow tests * Add two test cases for stridedslice * Fix docstring * Fix lint * Fix import * Fix test assert * Minor fix conv3d * Add more comments * Fix for dilation2d * Change newly added atan * Change newly added unravel --- python/tvm/relay/frontend/common.py | 42 +- python/tvm/relay/frontend/tensorflow.py | 843 +++++++++++++-------- .../frontend/tensorflow/test_control_flow.py | 71 +- tests/python/frontend/tensorflow/test_debugging.py | 8 +- tests/python/frontend/tensorflow/test_forward.py | 2 + 5 files changed, 641 insertions(+), 325 deletions(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 6185121..5465e50 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=broad-except """Common utilities""" from __future__ import absolute_import as _abs import logging @@ -482,24 +483,37 @@ def infer_channels(inputs, transpose=False): return channels -def infer_value(input_val, params): +def infer_value(input_val, params, mod=None): """A hack for getting the value of an expression by evaluating a portion of the relay graph. This is often needed for functions that whose output shape depends on the value of a tensor. """ - # pylint: disable=import-outside-toplevel - from tvm.contrib import graph_runtime - # Check that all free variables have associated parameters. - assert all(var.name_hint in params.keys() for var in analysis.free_vars( - input_val)), "All inputs to infer must be available in params." - func = _function.Function(analysis.free_vars(input_val), input_val) - with tvm.relay.build_config(opt_level=0): - graph, lib, params = tvm.relay.build(func, target="llvm", params=params) - ctx = tvm.cpu(0) - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**params) - m.run() - return m.get_output(0) + try: + # TODO(kevinthesun): Use VM for all cases. + # pylint: disable=import-outside-toplevel + from tvm.contrib import graph_runtime + # Check that all free variables have associated parameters. + assert all(var.name_hint in params.keys() for var in analysis.free_vars( + input_val)), "All inputs to infer must be available in params." 
+ func = _function.Function(analysis.free_vars(input_val), input_val) + with tvm.relay.build_config(opt_level=0): + graph, lib, params = tvm.relay.build(func, target="llvm", params=params) + ctx = tvm.cpu(0) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + m.run() + return m.get_output(0) + except Exception: + if isinstance(mod, IRModule): + mod["main"] = _expr.Function(analysis.free_vars(input_val), input_val) + else: + mod = IRModule.from_expr(input_val) + exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm") + inputs = [] + for param in mod['main'].params: + inputs.append(tvm.nd.array(params[param.name_hint])) + result = exc.evaluate()(*inputs) + return result def infer_value_simulated(input_val, params): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 9cdd68b..d0b90e5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -23,17 +23,17 @@ from collections import defaultdict # Numpy support import numpy as np - import tvm from tvm.ir import IRModule from tvm.relay.prelude import Prelude +from tvm.relay.analysis import structural_hash as s_hash from .. import analysis from .. import expr as _expr from .. import function as _function from .. import op as _op -from ..expr_functor import ExprMutator +from ..expr_functor import ExprMutator, ExprVisitor from .common import AttrCvt, get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape @@ -92,43 +92,40 @@ def _get_list_param(params, input_node): def _get_tuple_param(params, input_node): return tuple(_get_param(params, input_node)) -def _need_module_for_shape_inference(op): - return op in ['StridedSlice'] - def _need_prelude_for_shape_inference(op): return "TensorArray" in op def _rsqrt(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): inputs.append(tvm.relay.const(-0.5, attr['T'].name)) return AttrCvt(op_name="power")(inputs, attr) return _impl def _argx(func, func_name): """ A common wrapper for argmin and argmax operations """ - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): try: # In Tensorflow, `axis` argument is a Tensor, not attribute. We # support the case where it inputs from a scalar constant. axis_input_value = [_get_num_param(params, inputs[1])] except (IndexError, KeyError): - raise TypeError( \ + raise TypeError( "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) return func(inputs[0], axis=axis_input_value, keepdims=False) return _impl def _elemwise(name): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) return get_relay_op(name)(*inputs) return _impl def _pool3d(name): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False - input_shape = attr['_input_shapes'][inputs[0]] + input_shape = _infer_shape(inputs[0], mod) if attr['data_format'] == 'NDHWC': attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3]) @@ -141,10 +138,9 @@ def _pool3d(name): 'is not valid.' 
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if attr['data_format'] == "NDHWC": - input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)] + input_shape = [_infer_shape(inputs[0], mod)[i] for i in (0, 4, 1, 2, 3)] inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3)) attr['data_format'] = "NCDHW" - attr['_input_shapes'][inputs[0]] = input_shape flip_layout = True attr['padding'] = attr['padding'].decode("utf-8") @@ -188,12 +184,12 @@ def _pool3d(name): return _impl def _pooling(name): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False - input_shape = attr['_input_shapes'][inputs[0]] + input_shape = _infer_shape(inputs[0], mod) if attr['data_format'] == 'NHWC': attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2]) @@ -207,7 +203,7 @@ def _pooling(name): raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format'])) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": - tmp_shape = attr['_input_shapes'][inputs[0]] + tmp_shape = _infer_shape(inputs[0], mod) input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) attr['data_format'] = "NCHW" @@ -256,17 +252,16 @@ def _pooling(name): return _impl def _conv(opname): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False if opname == 'conv_transpose' and attr['data_format'] == 'NHWC': # transform to NCHW for TVM backend compatible and set 'flip_layout' # to have output flip back to NHWC - tmp_shape = attr['_input_shapes'][inputs[2]] + tmp_shape = _infer_shape(inputs[2], mod) tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2)) - attr['_input_shapes'][inputs[2]] = tmp_shape attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ attr['strides'][3], attr['strides'][1], attr['strides'][2] attr['data_format'] = 'NCHW' @@ -281,19 +276,19 @@ def _conv(opname): inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] # NCHW Layout require weights transpose + weights_shape = _infer_shape(inputs[1]) if attr['data_format'] == 'NCHW': - tmp_shape = attr['_input_shapes'][inputs[1]] + tmp_shape = weights_shape if opname in ['conv', 'conv_transpose']: tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) else: tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) - attr['_input_shapes'][inputs[1]] = tmp_shape + weights_shape = tmp_shape - input_shape = attr['_input_shapes'][inputs_data] - weights_shape = attr['_input_shapes'][inputs[1]] + input_shape = _infer_shape(inputs_data) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) @@ -390,8 +385,8 @@ def _conv(opname): # Ignore the new attributes from TF2.0, for now. 
out = AttrCvt( - op_name=_dimension_picker('conv', \ - surfix="_transpose" if opname == 'conv_transpose' else ""), + op_name=_dimension_picker('conv', + surfix="_transpose" if opname == 'conv_transpose' else ""), ignores=['explicit_paddings'], transforms={ 'kernel_shape': 'kernel_size', @@ -414,12 +409,12 @@ def _conv(opname): # Dilation2d def _dilation2d(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): if 'data_format' not in attr: attr['data_format'] = 'NHWC' - input_shape = attr['_input_shapes'][inputs[0]] - weights_shape = attr['_input_shapes'][inputs[1]] + input_shape = _infer_shape(inputs[0], mod) + weights_shape = _infer_shape(inputs[1], mod) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] @@ -497,21 +492,21 @@ def _dilation2d(): def _conv3d(opname): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] # NCDHW Layout require weights transpose + weights_shape = _infer_shape(inputs[1], mod) if attr['data_format'] == 'NCDHW': - tmp_shape = attr['_input_shapes'][inputs[1]] + tmp_shape = weights_shape tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)] inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2)) - attr['_input_shapes'][inputs[1]] = tmp_shape + weights_shape = tmp_shape - input_shape = attr['_input_shapes'][inputs_data] - weights_shape = attr['_input_shapes'][inputs[1]] + input_shape = _infer_shape(inputs_data, mod) if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC": input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)] @@ -532,7 +527,7 @@ def _conv3d(opname): attr['channels'] = weights_shape[3] if 'dilations' in attr: - attr['dilations'] =\ + attr['dilations'] = \ (attr['dilations'][1], attr['dilations'][2], attr['dilations'][3]) attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3]) elif attr['data_format'] == 'NCDHW': @@ -544,7 +539,7 @@ def _conv3d(opname): attr['channels'] = weights_shape[1] if 'dilations' in attr: - attr['dilations'] =\ + attr['dilations'] = \ (attr['dilations'][2], attr['dilations'][3], attr['dilations'][4]) attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4]) else: @@ -599,8 +594,8 @@ def _conv3d(opname): # Ignore the new attributes from TF2.0, for now. out = AttrCvt( - op_name=_dimension_picker('conv', \ - surfix="_transpose" if opname == 'conv_transpose' else ""), + op_name=_dimension_picker('conv', + surfix="_transpose" if opname == 'conv_transpose' else ""), ignores=['explicit_paddings'], transforms={ 'kernel_shape': 'kernel_size', @@ -621,19 +616,19 @@ def _conv3d(opname): return _impl def _decode_image(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. 
warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input") return inputs[0] return _impl def _unravel_index(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return _op.unravel_index(inputs[0], inputs[1]) return _impl def _crop_and_resize(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] try: @@ -654,12 +649,12 @@ def _crop_and_resize(): return _impl def _cast(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return inputs[0].astype(attr['DstT'].name) return _impl def _expand_dims(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): dim_input = inputs.pop(1) axis = _get_num_param(params, dim_input) return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], @@ -667,14 +662,16 @@ def _expand_dims(): return _impl def _resize(method): - def _impl(inputs, attr, params): - output_shape0 = attr['_output_shapes'][0] - # Dynamic size models might have _output_shapes attr equal to [None] here - size = output_shape0[1:3] if output_shape0 is not None else [-1, -1] - # Important that the size is defined. If an axis is not, we need to infer what - # the shape should be. - if -1 in size: + def _impl(inputs, attr, params, mod): + if attr['_output_shapes'][0] is not None: + size = attr['_output_shapes'][0][1:3] + # Important that the size is defined. If an axis is not, we need to infer what + # the shape should be. + if -1 in size: + size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + else: size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + attr['size'] = size inputs.pop(1) # NHWC @@ -691,7 +688,7 @@ def _resize(method): return _impl def _check_numerics(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # Making a copy node assuming no need to verify return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr) return _impl @@ -704,7 +701,7 @@ def _assert(): return _no_op() def _no_op(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # ToDo: This should really be an op that returns nothing, which could # be represented as an empty tuple. 
It turns out that TVM # infrastructure doesn't like running functions that return None and @@ -716,7 +713,7 @@ def _no_op(): return _impl def _matmul(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): channels = _infer_channels(inputs[1], not attr['transpose_b']) if attr['transpose_a']: inputs[0] = _op.transpose(inputs[0], axes=(1, 0)) @@ -729,11 +726,11 @@ def _matmul(): return _impl def _batch_matmul(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): input_x = inputs[0] input_y = inputs[1] - orig_shape_x = attr['_input_shapes'][input_x] - orig_shape_y = attr['_input_shapes'][input_y] + orig_shape_x = _infer_shape(input_x, mod) + orig_shape_y = _infer_shape(input_y, mod) # reshape n-dimensional batch matmul into 3d if len(orig_shape_x) > 3: @@ -761,12 +758,12 @@ def _batch_matmul(): return _impl def _identity(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return inputs[0] return _impl def _concatV2(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): pop_node = inputs.pop(len(inputs)-1) axis = int(_get_num_param(params, pop_node)) return AttrCvt( @@ -775,7 +772,7 @@ def _concatV2(): return _impl def _concat(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): pop_node = inputs.pop(0) axis = int(_get_num_param(params, pop_node)) return AttrCvt( @@ -784,7 +781,7 @@ def _concat(): return _impl def _pack(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): axis = int(attr["axis"]) inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] return _op.concatenate(inputs_reshaped, axis) @@ -854,7 +851,7 @@ def _tensor_array_concat(): return _impl def _tile(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): reps = _get_list_param(params, inputs.pop()) new_input = [] new_input.append(inputs.pop(0)) @@ -866,7 +863,7 @@ def _tile(): return _impl def _slice(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): try: begin = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): @@ -874,21 +871,26 @@ def _slice(): try: size = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): - size = _infer_value(inputs[2], params).asnumpy().tolist()[0] - data_shape = attr['_input_shapes'][inputs[0]] + # Handle symbolic size + try: + size = _infer_value(inputs[2], params).asnumpy().tolist()[0] + except Exception: + size = inputs[2] + data_shape = _infer_shape(inputs[0], mod) data_dim = len(data_shape) end = size - for i in range(data_dim): - if size[i] == -1: - end[i] = data_shape[i] - else: - end[i] += begin[i] + if not isinstance(end, (_expr.Call, _expr.Var)): + for i in range(data_dim): + if size[i] == -1: + end[i] = data_shape[i] + else: + end[i] += begin[i] return _op.strided_slice(inputs[0], begin=begin, end=end) return _impl def _reshape(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): pop_node = inputs.pop(1) try: @@ -917,7 +919,7 @@ def _reshape(): def _depth_to_space(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): block_size = int(attr['block_size']) layout = attr['data_format'].decode("utf-8") return _op.nn.depth_to_space(inputs[0], block_size, layout) @@ -926,7 +928,7 @@ def _depth_to_space(): def _space_to_depth(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): block_size = int(attr['block_size']) layout = 
attr['data_format'].decode("utf-8") return _op.nn.space_to_depth(inputs[0], block_size, layout) @@ -935,7 +937,7 @@ def _space_to_depth(): def _bias_add(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # Must expand for proper broadcasting in NCHW. if attr['data_format'].decode("utf-8") == 'NCHW': bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) @@ -945,7 +947,7 @@ def _bias_add(): return _impl def _broadcast_to(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): if isinstance(inputs[1], _expr.Var): shape = params[inputs[1].name_hint] else: @@ -955,7 +957,7 @@ def _broadcast_to(): return _impl def _squeeze(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): if len(attr['squeeze_dims']) == 0: attr['squeeze_dims'] = None return AttrCvt( @@ -965,7 +967,7 @@ def _squeeze(): return _impl def _fused_batch_norm(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # Tensorflow: (data, gamma, beta, moving_mean, moving_variance) # Relay: (data, gamma, beta, moving_mean, moving_varience) assert len(inputs) == 5 @@ -1001,7 +1003,7 @@ def _fused_batch_norm(): return _impl def _batch_norm(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # Rearrange inputs from # (data, moving_mean, moving_variance, beta, gamma) # to @@ -1023,14 +1025,15 @@ def _batch_norm(): return _impl def _relu6(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return _op.clip(inputs[0], a_min=0, a_max=6) return _impl def _shape(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): is_symbolic_shape = False - for axis in attr['_input_shapes'][inputs[0]]: + input_shape = _infer_shape(inputs[0], mod) + for axis in input_shape: if not isinstance(axis, (int, tvm.tir.IntImm)): is_symbolic_shape = True break @@ -1038,13 +1041,13 @@ def _shape(): if is_symbolic_shape: ret = _op.shape_of(inputs[0], dtype='int32') else: - ret = np.array(attr['_input_shapes'][inputs[0]], dtype='int32') + ret = np.array(input_shape, dtype='int32') return ret return _impl def _fill(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): output_shape = attr['_output_shapes'][0] # Output shape must be defined to avoid errors. If any axis is not, we must # try to compute its shape. 
@@ -1058,7 +1061,7 @@ def _fill(): return _impl def _lrn(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): attr_new = {} depth_radius = attr.get('depth_radius', 5) size = (depth_radius * 2) + 1 @@ -1071,7 +1074,7 @@ def _lrn(): return _impl def _sum(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): axis = _get_tuple_param(params, inputs[1]) return AttrCvt( op_name='sum', @@ -1081,7 +1084,7 @@ def _sum(): return _impl def _reduce(op): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): axis = _get_list_param(params, inputs[1]) axis = tuple(axis) return AttrCvt( @@ -1092,13 +1095,13 @@ def _reduce(op): return _impl def _square(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return _op.multiply(inputs[0], inputs[0]) return _impl def _gather(): "GatherV2, Gather" - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): if len(inputs) > 2: axis = _get_num_param(params, inputs.pop(2)) else: @@ -1115,7 +1118,7 @@ def _gather(): def _gather_nd(): """GatherNd""" - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return AttrCvt(op_name="gather_nd", ignores=['Tindices', 'Tparams',\ 'Taxis', '_class'])(inputs, attr) @@ -1136,7 +1139,7 @@ def _stridedSlice(): ellipsis_mask = int(attr.get('ellipsis_mask', 0)) new_axis_mask = int(attr.get('new_axis_mask', 0)) shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) - data_shape = attr['_input_shapes'][inputs[0]] + data_shape = _infer_shape(inputs[0], mod) data_dim = len(data_shape) stride_dim = len(stride) @@ -1164,8 +1167,8 @@ def _stridedSlice(): mask = 1 << index if mask & ellipsis_mask: #Identify the end index for applying ellipsis_mask - to_index = min(((data_dim - (stride_dim-index)) + 1 \ - + new_axes_after_ellipsis), data_dim) + to_index = min(((data_dim - (stride_dim-index)) + 1 + + new_axes_after_ellipsis), data_dim) for i in range(final_index, to_index): m_begin[final_index] = 0 m_end[final_index] = data_shape[final_index] @@ -1205,7 +1208,7 @@ def _stridedSlice(): if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_shape(out, mod=mod) + out_shape = _infer_shape(out, mod) if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -1220,12 +1223,25 @@ def _stridedSlice(): final_output.append(out_shape[gather_index]) if not final_output: - return out - return _op.reshape(out, newshape=tuple(final_output)) + if not shrink_axis_mask: + ret = out + else: + final_shape = [] + for dim in out_shape: + if dim != 1: + final_shape.append(dim) + if len(final_shape) == 0: + ret = _op.squeeze(out) + else: + # We need reshape to handle dynamic shape. 
+ ret = _op.reshape(out, newshape=tuple(final_shape)) + else: + ret = _op.reshape(out, newshape=tuple(final_output)) + return ret return _impl def _pad(name): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): padlist = _get_param(params, inputs[1]) paddings = tuple(tuple(l) for l in padlist) attr['pad_width'] = paddings @@ -1240,7 +1256,7 @@ def _pad(name): return _impl def _mirror_pad(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): padlist = _get_param(params, inputs[1]) paddings = tuple(tuple(l) for l in padlist) attr['pad_width'] = paddings @@ -1253,7 +1269,7 @@ def _mirror_pad(): return _impl def _transpose(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): # If perm is not specified, axes is left empty, # otherwise its value is get from params try: @@ -1264,21 +1280,21 @@ def _transpose(): return _impl def _where(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): if len(inputs) == 1: return AttrCvt(op_name="argwhere")(inputs, attr) return AttrCvt(op_name="where")(inputs, attr) return _impl def _clip_by_value(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): a_min = _get_num_param(params, inputs[1]) a_max = _get_num_param(params, inputs[2]) return _op.clip(inputs[0], a_min=a_min, a_max=a_max) return _impl def _reverse_v2(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): axis = _get_num_param(params, inputs[1]) return AttrCvt( op_name="reverse", @@ -1287,8 +1303,8 @@ def _reverse_v2(): return _impl def _rank(): - def _impl(inputs, attr, params): - input_shape = attr['_input_shapes'][inputs[0]] + def _impl(inputs, attr, params, mod): + input_shape = _infer_shape(inputs[0], mod) name = attr["_node_name"] params[name] = tvm.nd.array([len(input_shape)]) @@ -1298,31 +1314,61 @@ def _rank(): return _impl - def _range(): - def _impl(inputs, attr, params): - start = _get_param(params, inputs[0])[0] + def _impl(inputs, attr, params, mod): + try: + start = _get_param(params, inputs[0])[0] + except (IndexError, KeyError, AttributeError): + try: + start = _infer_value(inputs[1], params).asnumpy().tolist() + start = start if not isinstance(start, list) else start[0] + except Exception: + # Symbolic start + start = inputs[0] + if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant): limit = _get_param(params, inputs[1])[0] else: if any(['Rank' in param for param in params]): limit = params.pop('Rank').asnumpy()[0] else: - limit = _infer_value_simulated(inputs[1], params).asnumpy()[0] - delta = _get_param(params, inputs[2])[0] + try: + limit = _infer_value(inputs[1], params, mod).asnumpy().tolist() + limit = limit if not isinstance(limit, list) else limit[0] + except Exception: + # Symbolic limit + limit = inputs[1] + + try: + delta = _get_param(params, inputs[2])[0] + except (IndexError, KeyError, AttributeError): + try: + delta = _infer_value(inputs[2], params, mod).asnumpy().tolist() + delta = delta if not isinstance(delta, list) else delta[0] + except Exception: + # Symbolic delta + delta = inputs[2] + + dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype) + if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)): + start = _expr.const(start) + if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)): + limit = _expr.const(limit) + if isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)): + delta = _expr.const(delta) + return AttrCvt( 
op_name="arange", ignores=['Tidx'], - extras={'start': _expr.const(start), - "stop": _expr.const(limit), - 'step': _expr.const(delta), + extras={'start': start, + 'stop': limit, + 'step': delta, 'dtype': dtype})([], attr) return _impl - def _elu(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): dtype = attr['T'].name alpha = tvm.relay.const(-1.0, dtype) return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ @@ -1330,16 +1376,16 @@ def _elu(): return _impl def _selu(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): dtype = attr['T'].name alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) - return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ + return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])) return _impl def _mean(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): axis = _get_tuple_param(params, inputs[1]) return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], transforms={'keep_dims': 'keepdims'}, @@ -1347,7 +1393,7 @@ def _mean(): return _impl def _broadcast(name): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return AttrCvt( op_name=name, ignores=['name', 'incompatible_shape_error', 'Tidx'] @@ -1356,7 +1402,7 @@ def _broadcast(name): def _split(has_size_vector): # TF documentation https://www.tensorflow.org/api_docs/python/tf/split - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): try: # order and number of inputs are different: # if has_size_vector: @@ -1379,8 +1425,8 @@ def _split(has_size_vector): input_node = inputs[input_node_index] axis_input_value = _get_num_param(params, inputs[input_axis_index]) except (IndexError, KeyError): - raise TypeError( \ - "Unsupported argument for split: `axis` and `num_or_size_splits` " \ + raise TypeError( + "Unsupported argument for split: `axis` and `num_or_size_splits` " "should be constants") return _op.split(input_node, indices_or_sections=indices_or_sections, @@ -1388,35 +1434,31 @@ def _split(has_size_vector): return _impl def _unpack(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): input_node = inputs[0] axis = attr['axis'] - input_shape = attr['_input_shapes'][input_node] + input_shape = _infer_shape(input_node, mod) axis_length = input_shape[axis] if axis_length < 0: raise TypeError("Unstack with unknown axis length") splitted = _op.split(input_node, indices_or_sections=axis_length, axis=axis) - #name=attr.get('_node_name', 'unstack')) - if axis == 0: - axis = None - else: - axis = [axis] + axis = [axis] return _expr.TupleWrapper( _expr.Tuple([_op.squeeze(split_item, axis=axis) \ for split_item in splitted]), len(splitted)) return _impl def _softmax(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): return AttrCvt(op_name='softmax', transforms={'axis': ('axis', 1)})([inputs[0]], attr) return _impl def _softplus(): # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): exp_out = AttrCvt('exp')(inputs, attr) inputs.append(tvm.relay.const(1, attr['T'].name)) rh = tvm.relay.const(1, attr['T'].name) @@ -1425,8 +1467,12 @@ def _softplus(): return _impl def _topk(): - def _impl(inputs, attr, params): - k = int(_get_num_param(params, inputs.pop(1))) + def _impl(inputs, attr, params, mod): + k_input = 
inputs.pop(1)
+        try:
+            k = int(_get_num_param(params, k_input))
+        except (IndexError, KeyError, AttributeError):
+            k = int(_infer_value(k_input, params).asnumpy().tolist())
         if k < 1:
             raise tvm.error.OpAttributeInvalid(
                 'Attribute k must be positive in operator TopKV2')
@@ -1439,28 +1485,39 @@ def _topk():
     return _impl

 def _floordiv():
-    def _impl(inputs, attr, params):
+    def _impl(inputs, attr, params, mod):
         assert len(inputs) == 2
         return AttrCvt('floor_divide')(inputs, attr)
     return _impl

 def _floormod():
-    def _impl(inputs, attr, params):
+    def _impl(inputs, attr, params, mod):
         assert len(inputs) == 2
         return AttrCvt('floor_mod')(inputs, attr)
     return _impl

 def _logical(name):
-    def _impl(inputs, attr, params):
+    def _impl(inputs, attr, params, mod):
         return AttrCvt(op_name=name)(inputs, attr)
     return _impl

 def _space_to_batch_nd():
-    def _impl(inputs, attr, params):
+    def _impl(inputs, attr, params, mod):
         input_node = inputs[0]
-        input_shape = attr['_input_shapes'][input_node]
-        block_shape = _get_list_param(params, inputs[1])
-        paddings = _get_list_param(params, inputs[2])
+        input_shape = _infer_shape(input_node, mod)
+        try:
+            block_shape = _get_list_param(params, inputs[1])
+        except (IndexError, KeyError, AttributeError):
+            block_shape = _infer_value(inputs[1], params).asnumpy().tolist()
+
+        try:
+            paddings = _get_list_param(params, inputs[2])
+        except (IndexError, KeyError, AttributeError):
+            paddings = _infer_value(inputs[2], params).asnumpy()
+            paddings = np.squeeze(paddings)
+            if len(paddings.shape) == 1:
+                paddings = np.expand_dims(paddings, axis=0)
+            paddings = paddings.tolist()
         N = len(input_shape)
         M = len(block_shape)
         batch = input_shape[0]
@@ -1495,18 +1552,29 @@ def _space_to_batch_nd():


 def _batch_to_space_nd():
-    def _impl(inputs, attr, params):
+    def _impl(inputs, attr, params, mod):
         input_node = inputs[0]
-        input_shape = attr['_input_shapes'][input_node]
-        block_shape = _get_list_param(params, inputs[1])
-        crops = _get_list_param(params, inputs[2])
+        input_shape = _infer_shape(input_node, mod)
+        try:
+            block_shape = _get_list_param(params, inputs[1])
+        except (IndexError, KeyError, AttributeError):
+            block_shape = _infer_value(inputs[1], params).asnumpy().tolist()
+
+        try:
+            crops = _get_list_param(params, inputs[2])
+        except (IndexError, KeyError, AttributeError):
+            crops = _infer_value(inputs[2], params).asnumpy()
+            crops = np.squeeze(crops)
+            if len(crops.shape) == 1:
+                crops = np.expand_dims(crops, axis=0)
+            crops = crops.tolist()
         M = len(block_shape)
         batch = input_shape[0]
         # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
         # Reshape input to reshaped of shape:
         # [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape),
         #  input_shape[1], ..., input_shape[N-1]]
-        shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:]
+        shape1 = block_shape + [batch // np.prod(block_shape)] + list(input_shape[1:])
         reshaped = tvm.relay.reshape(input_node, newshape=shape1)
         # Permute dimensions of reshaped to produce permuted of shape
         # [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
@@ -1541,13 +1609,13 @@ def _batch_to_space_nd():
     return _impl

 def _atan2():
-    def _impl(inputs, attr, params):
-        divide = _elemwise("divide")(inputs, attr, params)
+    def _impl(inputs, attr, params, mod):
+        divide = _elemwise("divide")(inputs, attr, params, mod)
         return get_relay_op("atan")(divide)
     return _impl

 def _prod():
-    def _impl(inputs, attr, params):
+    def _impl(inputs, attr, params, mod):
         axis = _get_num_param(params,
inputs[1]) keepdims = attr['keep_dims'] return _op.prod(inputs[0], int(axis), keepdims=keepdims) @@ -1555,21 +1623,21 @@ def _prod(): def _log1p(): # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): one = tvm.relay.const(1, attr['T'].name) add_out = get_relay_op('add')(inputs[0], one) return get_relay_op('log')(add_out) return _impl def _one_hot(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): depth = int(_get_num_param(params, inputs[1])) dtype = attr['T'].name on_value = _get_num_param(params, inputs[2]) off_value = _get_num_param(params, inputs[3]) - new_inputs = [inputs[0], \ - tvm.relay.const(on_value, dtype), \ + new_inputs = [inputs[0], + tvm.relay.const(on_value, dtype), tvm.relay.const(off_value, dtype)] return AttrCvt('one_hot', ignores=['TI'], @@ -1577,20 +1645,20 @@ def _one_hot(): return _impl def _squared_difference(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): difference = _op.subtract(inputs[0], inputs[1]) return _op.multiply(difference, difference) return _impl def _size(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): new_attr = attr new_attr['out_type'] = attr['out_type'].name return AttrCvt('ndarray_size', transforms={'out_type' : 'dtype'})(inputs, new_attr) return _impl def _add_n(): - def _impl(inputs, attr, params): + def _impl(inputs, attr, params, mod): if not isinstance(inputs, tuple): inputs = list(inputs) assert len(inputs) > 0, "add_n take >=1 inputs, but 0 given." @@ -1758,7 +1826,7 @@ _convert_map = { } def _LSTMBlockCell(): - def _impl(inputs, in_state_c, in_state_h, attr, params): + def _impl(inputs, in_state_c, in_state_h, attr, params, mod): """LSTM Block cell. 
Calculations are described in: https://github.com/tensorflow/tensorflow/blob/ r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114 @@ -1787,8 +1855,8 @@ def _LSTMBlockCell(): in_weight = inputs[3] in_bias = inputs[7] forget_bias = attr.pop('forget_bias') - input_shape = attr['_input_shapes'][inputs[0]] - weight_shape = attr['_input_shapes'][inputs[3]] + input_shape = _infer_shape(inputs[0], mod) + weight_shape = _infer_shape(inputs[3], mod) batch_size, input_size = input_shape[0], input_shape[1] num_hidden_layers = weight_shape[1] num_hidden = num_hidden_layers // 4 @@ -1883,7 +1951,7 @@ class RecurrentNetworks(object): sym : relay.Expr The returned relay Expr """ - def _impl(op_name, layer_name, inputs, attrs, params, num_layers): + def _impl(op_name, layer_name, inputs, attrs, params, num_layers, mod): in_state_c_name = layer_name+'_c' in_state_h_name = layer_name+'_h' @@ -1914,8 +1982,8 @@ class RecurrentNetworks(object): def _LSTMBlockCellWrapper(inputs, attr, params, num_layers, layer): """LSTM cell warapper to prepare the inputs""" - input_shape = attr['_input_shapes'][inputs[0]] - weight_shape = attr['_input_shapes'][inputs[3]] + input_shape = _infer_shape(inputs[0], mod) + weight_shape = _infer_shape(inputs[3], mod) batch_size = input_shape[0] num_hidden = weight_shape[1] // 4 @@ -1928,13 +1996,13 @@ class RecurrentNetworks(object): in_state_c = self._nodes[in_state_c_name] in_state_h = self._nodes[in_state_h_name] - cur_in_state_c, cur_in_state_h = _get_cur_input_state( \ - in_state_c, in_state_h, - num_layers, layer, - batch_size, num_hidden) + cur_in_state_c, cur_in_state_h = _get_cur_input_state( + in_state_c, in_state_h, + num_layers, layer, + batch_size, num_hidden) output, out_state = self._convert_map[op_name](inputs, cur_in_state_c, cur_in_state_h, - attr, params) + attr, params, mod) return output, out_state, in_state_c, in_state_h sym, cur_out_state, in_state_c, in_state_h = \ @@ -1948,7 +2016,7 @@ class RecurrentNetworks(object): return sym return _impl - def process_op(self, op_name, inputs, attrs, params): + def process_op(self, op_name, inputs, attrs, params, mod): """Process recurrent layer operators. List '_recurrent_ops_layer_map' map each Layer based operators with its @@ -1998,7 +2066,7 @@ class RecurrentNetworks(object): num_layers += 1 sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs, - params, num_layers) + params, num_layers, mod) return sym # An internal list to contain all the control flow primitives used in Tensorflow @@ -2051,7 +2119,6 @@ def _in_while_loop(control_flow_node_map, op_name): return op_name in control_flow_node_map and \ "LoopCond" in control_flow_node_map[op_name] - class Branch: """A class contains the components that are used to build up a Relay if node. @@ -2133,6 +2200,82 @@ class Branch: return self._if +class LoopBound(ExprVisitor): + """ + When a loop body is create, we get a Relay expression backtracing all + the way back to input node. This will result in lots of unnecessary + expression placed into loop body and compute multiple times. For example, + consider the following tensorflow code: + + .. code-block:: python + + i = tf.constant(0) + data = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024)) + slice = tf.strided_slice(data, 0, 512) + def c(i): return tf.less(i, 10) + def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice] + r = tf.while_loop(c, b, [i]) + + If we directly create recursive function, slice will be placed into function body. 
+    Instead, we recognize whether slice is inside the while_loop block and pass it as an
+    extra loop variable to avoid duplicate computation.
+
+    TODO(kevinthesun): Add a LICM pass for Relay to handle generic loop/function.
+    """
+    def __init__(self, loop_name, hash2tfnode, while_loop_name_set):
+        ExprVisitor.__init__(self)
+        self._loop_name = loop_name
+        self._hash2tfnode = hash2tfnode
+        self._while_loop_name_set = while_loop_name_set
+        self.extra_loop_var_names = set()
+
+    def _find_parent_loop_name(self, node_name):
+        """Find the name of the direct parent while loop."""
+        ploop_name = ""
+        name_prefix = node_name.rsplit('/', 1)[0]
+        if name_prefix.startswith("^"):
+            name_prefix = name_prefix[1:]
+        # To get the name of the direct parent while loop for a given node,
+        # we iterate over all the while loop names inside the TensorFlow graph def.
+        # If we find a loop name with which the current node name starts,
+        # it means the current node is under this loop. However, due to nested
+        # loops, this loop may not be the direct parent while loop of the current
+        # node. We need to keep the longest loop name, which represents the
+        # innermost while loop corresponding to the current node.
+        for lname in self._while_loop_name_set:
+            if name_prefix.startswith(lname) and len(ploop_name) < len(lname):
+                ploop_name = lname
+
+        if len(ploop_name) == 0:
+            ploop_name = name_prefix
+
+        return ploop_name
+
+    def visit(self, expr):
+        """
+        For each expression in the body, look up the corresponding
+        TensorFlow node with its structural hash. If the current loop is the
+        direct parent of this node, we check whether every one of its input nodes
+        belongs to the current loop. If not, we mark that input node as an extra
+        loop variable of the current loop.
+        """
+        expr_hash = s_hash(expr)
+
+        if expr_hash in self._hash2tfnode:
+            node = self._hash2tfnode[expr_hash]
+            ploop_name = self._find_parent_loop_name(node.name)
+            # It is possible that a node is under a nested loop of the current loop.
+            # We only check the direct children of the current loop.
+ if ploop_name == self._loop_name: + for iname in node.input: + iploop_name = self._find_parent_loop_name(iname) + # Use startswith to deal with nested loop + if not iploop_name.startswith(self._loop_name): + if iname not in self.extra_loop_var_names: + self.extra_loop_var_names.add(iname) + super().visit(expr) + + class Loop: """ A class contains the components that are used to build up a Relay @@ -2189,11 +2332,18 @@ class Loop: %6 } """ - def __init__(self): + def __init__(self, mod, loop_name, hash2tfnode, + node_map, while_loop_name_set): self.loop_vars = [] self.cond = None self.body = [] self._loop = None + self._mod = mod + self._loop_name = loop_name + self._hash2tfnode = hash2tfnode + self._node_map = node_map + self._while_loop_name_set = while_loop_name_set + self.aligned = False def _while_loop(self): """An internal API to create a Relay recursive call for a matched TF @@ -2203,11 +2353,30 @@ class Loop: sb = tvm.relay.scope_builder.ScopeBuilder() + loop_checker = LoopBound(self._loop_name, + self._hash2tfnode, + self._while_loop_name_set) + for body in self.body: + loop_checker.visit(body) + loop_vars = [] bind_map = {} + loop_var_hash_set = set() + for var in self.loop_vars: + loop_var_hash_set.add(s_hash(var)) + + extra_nodes = [] + for extra_loop_var_name in loop_checker.extra_loop_var_names: + extra_loop_var_name = extra_loop_var_name.split(':')[0].split("^")[-1] + extra_node = self._node_map[extra_loop_var_name] + extra_node = extra_node if isinstance(extra_node, _expr.Tuple) else extra_node[0] + if s_hash(extra_node) not in loop_var_hash_set: + self.loop_vars.append(extra_node) + extra_nodes.append(extra_node) + for i, var in enumerate(self.loop_vars): if not isinstance(var, _expr.Var): - var_chk = _infer_type(var) + var_chk = _infer_type(var, self._mod) var_type = var_chk.checked_type else: var_type = var.type_annotation @@ -2216,21 +2385,37 @@ class Loop: loop_vars.append(v) bind_map[var] = v + self.cond = rewrite_subgraph(self.cond, bind_map) self.body = [rewrite_subgraph(b, bind_map) for b in self.body] + self.body_shape = [] + for body in self.body: + current_node = body + shape = _infer_shape(current_node, self._mod) + while not isinstance(shape, (tuple, list)): + current_node = current_node.args[-1] + shape = _infer_shape(current_node, self._mod) + self.body_shape.append(shape) + cond = tvm.relay.op.min(self.cond) with sb.if_scope(cond): - sb.ret(wl(*self.body)) + extra_args = [] + if extra_nodes: + extra_args = list(loop_vars[-len(extra_nodes):]) + sb.ret(wl(*list(self.body + extra_args))) with sb.else_scope(): sb.ret(tvm.relay.Tuple(loop_vars)) loop_fn = tvm.relay.Function(loop_vars, sb.get()) sb = tvm.relay.scope_builder.ScopeBuilder() sb.let(wl, loop_fn) - sb.ret(wl(*self.loop_vars)) - return sb.get() + loop_ret = wl(*self.loop_vars) + + sb.ret(loop_ret) + ret = sb.get() + return ret def while_loop(self): """Instantiate a while loop if it has not been created yet.""" @@ -2247,16 +2432,21 @@ class GraphProto(object): """ def __init__(self): self._nodes = {} + self._tf_node_map = {} self._params = {} self._input_shapes = {} self._output_shapes = {} - self._num_param = 0 self._num_rnn_layer = False self._input_shapes = {} self._loops = {} self._branches = {} self._mod = IRModule({}) self._prelude = Prelude(self._mod) + self._control_flow_node_map = defaultdict(set) + self._loop_body_order = {} + self._loop_var_order = {} + self._hash2tfnode = {} + self._while_loop_name_set = set() def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): 
"""Construct relay nodes from tensorflow graph definition - GraphDef. @@ -2296,7 +2486,6 @@ class GraphProto(object): params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ - try: from tensorflow.python.framework import tensor_util except ImportError as e: @@ -2304,6 +2493,10 @@ class GraphProto(object): "Unable to import tensorflow which is required {}".format(e)) missing_operators = self._parse_import_prerequisites(graph) + control_flow_nodes = [] + self._in_shape = shape + self._layout = layout + self._graph = graph if missing_operators: freezed_ops = [op for op in missing_operators if op in _freezed_graph_pruned_op_list] @@ -2311,13 +2504,24 @@ class GraphProto(object): raise Exception("Graph is not frozen. Provide a frozen graph. " "Found operators {}".format(freezed_ops)) - raise NotImplementedError( \ + raise NotImplementedError( "The following operators are not implemented: {}".format(missing_operators)) - control_flow_node_map = defaultdict(set) for node in graph.node: node_name_prefix = node.name.rsplit('/', 1)[0] - control_flow_node_map[node_name_prefix].add(node.op) + self._control_flow_node_map[node_name_prefix].add(node.op) + self._tf_node_map[node.name] = node + + # Parse output_shapes attribute + parsed_attr = self._parse_attr(node.attr) + if '_output_shapes' in parsed_attr: + self._output_shapes[node.name] = \ + [tensor_util.TensorShapeProtoToList(tshape) \ + for tshape in parsed_attr['_output_shapes']] + else: + self._output_shapes[node.name] = [None] + + # Parse placeholder and const here since input shape info is required. if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault': # Give priority to user argument. if shape and node.name in shape: @@ -2342,120 +2546,53 @@ class GraphProto(object): tensor_value = node.attr['value'].tensor self._input_shapes[node.name] = \ tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape) + self._output_shapes[node.name] = [self._input_shapes[node.name]] if shape and node.name in shape: warnings.warn("Ignore the passed shape. Shape in graphdef " "will be used for operator %s." % node.name) - - # Parse the nodes to re-create TF graph using Relay operators. - for node in graph.node: - # Tensorflow doesn't have separate list for params extraction. - # Operator name 'Const' is treated as a parameter to build params dict. - - input_shapes = {} - attr = self._parse_attr(node.attr) - - # Variable converted to Const will not have only value attr - if 'value' in attr and node.op == 'Const': - self._output_shapes[node.name] = [self._input_shapes[node.name]] - elif '_output_shapes' in attr: - self._output_shapes[node.name] = \ - [tensor_util.TensorShapeProtoToList(tshape) \ - for tshape in attr['_output_shapes']] - else: - # Keep the list indexable to avoid key error. - # Actual value will be filled after node creation. 
- # Will infer shapes if the graph is not frozen with add_shapes=True - self._output_shapes[node.name] = [None] - - if node.op == "Const": - # All Const nodes are Param nodes, lets parse - self._num_param += 1 for key, value in node.attr.items(): - self._parse_param(key, value, node.name, shape) - if node.name not in self._nodes: - raise NotImplementedError( \ - "Const {} couldn't be converted to Param.".format(node.name)) - - attr = self._parse_attr(node.attr) - - elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault': - # Pass the parsed shapes instead - attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] - - # Pass the node name too in attr - attr["_node_name"] = node.name - - # Pass the target layout - attr["_target_layout"] = layout - - # Fill shapes for all inputs in a list - inputs = [] - for i in node.input: - # Some TensorFlow operators internally maintain execution layers - # and their output name includes the layer number along with - # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the - # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case, - # the number has to be ignored for single-output nodes. - # On the other hand, for multi-output nodes the number is the output index, - # and the lack of the number implies 0. - tensor_name = i.split(':') - node_name = tensor_name[0] - if node_name in self._nodes: - in_sym = self._nodes[node_name] - if isinstance(in_sym, _expr.TupleWrapper): - tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0 - in_sym = [in_sym[tensor_slot]] - input_shape = self._output_shapes[node_name][tensor_slot] - else: - tensor_slot = 0 - input_shape = self._output_shapes[node_name][0] - inputs.append(in_sym[0]) - input_shapes[in_sym[0]] = input_shape - - attr['_input_shapes'] = input_shapes - - if node.op in _control_flow_nodes: - op = self._convert_control_flow_operator(node, inputs, - attr, - control_flow_node_map) - else: - op = self._convert_operator(node.op, inputs, attr, graph) - - # Check if op is converted to param - if isinstance(op, np.ndarray): - self._params[node.name] = tvm.nd.array(op) - op = [_expr.var(node.name, - shape=self._params[node.name].shape, - dtype=self._params[node.name].dtype)] - - elif isinstance(op, (_expr.TupleWrapper, tuple, list)): - pass - elif isinstance(op, _expr.Expr): - op = [op] - else: - raise RuntimeError("unexpected type %s" % type(op)) + self._parse_param(key, value, node.name, self._in_shape) + elif node.op in _control_flow_nodes: + # We assume that the direct parent node of Exit is a while loop block + if node.op == "Exit": + self._while_loop_name_set.add(node_name_prefix) + control_flow_nodes.append(node) + + # First, parse all control flow nodes. + # Convert tf.cond to Branch and tf.while_loop to Loop. + sorted_cf_nodes = [] + current_node_name_prefix = None + exits = [] + # Sort control flow nodes to move all Exit nodes to the end + # of corresponding while_loop block. 
+ for i, node in enumerate(control_flow_nodes): + node_name_prefix = node.name.rsplit('/', 1)[0] + if current_node_name_prefix is None or current_node_name_prefix != node_name_prefix: + if node_name_prefix in self._while_loop_name_set: + sorted_cf_nodes.extend(exits) + exits.clear() + current_node_name_prefix = node_name_prefix - self._nodes[node.name] = op + if node.op == "Exit": + exits.append(node) + else: + sorted_cf_nodes.append(node) - # Infer shapes even without specifying "add_shapes=True" - if output_shapes == [None]: - out_shapes = [_infer_shape(node_item, self._mod) - for node_item in self._nodes[node.name]] - self._output_shapes[node.name] = out_shapes + if i == len(control_flow_nodes) - 1: + sorted_cf_nodes.extend(exits) - if self._output_shapes[node.name] and shape and node.name in shape: - assert self._output_shapes[node.name] == list(shape[node.name]) + for node in sorted_cf_nodes: + self._backtrack_construct(node.name) - # Infer shapes if passed explicitly - node_output = self._nodes[node.name] - if shape and (not self._output_shapes[node.name][0] - or -1 in self._output_shapes[node.name][0]): - out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output] - self._output_shapes[node.name] = out_shapes + # Second, parse other nodes to re-create TF graph using Relay operators. + for node in graph.node: + self._backtrack_construct(node.name) out = [] if outputs is None: - if node.op == "Exit": + last_node = graph.node[-1] + op = self._nodes[last_node.name.split(":")[0]] + if last_node.op == "Exit": out = [op[0].tuple_value] else: out = op @@ -2620,7 +2757,7 @@ class GraphProto(object): self._out_rnn = [] self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map) self._num_rnn_layer = True - sym = self.rnn.process_op(op_name, inputs, attrs, params) + sym = self.rnn.process_op(op_name, inputs, attrs, params, self._mod) return sym def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map): @@ -2651,53 +2788,95 @@ class GraphProto(object): """ node_name_prefix = node.name.rsplit('/', 1)[0] if node.op == "Merge": - if _in_while_loop(control_flow_node_map, node_name_prefix): - op = self._nodes[node.input[0]] - self._loops[node_name_prefix] = Loop() + if _in_while_loop(self._control_flow_node_map, node_name_prefix): + op = self._backtrack_construct(node.input[0]) + if node_name_prefix not in self._loops: + self._loops[node_name_prefix] = Loop(self._mod, + node_name_prefix, + self._hash2tfnode, + self._nodes, + self._while_loop_name_set) else: if len(self._branches) == 0: raise RuntimeError("Cannot find a created " "conditional for merge node") branch = self._branches[node_name_prefix] - false_br = self._nodes[node.input[0]] - true_br = self._nodes[node.input[1]] + false_br = self._backtrack_construct(node.input[0]) + true_br = self._backtrack_construct(node.input[1]) assert len(true_br) == 1 assert len(false_br) == 1 branch.true_branch = true_br[0] branch.false_branch = false_br[0] op = [branch.if_node()] + if node_name_prefix not in self._while_loop_name_set: + try: + cond_val = np.all(_infer_value(branch.cond, self._params, + self._mod).asnumpy()) + if cond_val: + op = [branch.true_branch] + else: + op = [branch.false_branch] + except Exception: + op = [branch.if_node()] elif node.op == "Exit": loop = self._loops[node_name_prefix] - exit_name = node.name.split('/')[-1] - assert str.startswith(exit_name, 'Exit') - # TensorFlow has differen naming convention on different - # versions. 
+ # Check whether the order of loop variables aligns + # with loop body. If not, create new loop variable list + # with correct order. + if not loop.aligned: + loop_vars = [] + for i in self._loop_body_order[node_name_prefix]: + for j, k in enumerate(self._loop_var_order[node_name_prefix]): + if k == i: + loop_vars.append(loop.loop_vars[j]) + loop.loop_vars = loop_vars + loop.aligned = True + exit_name = node.name.split('/')[-1] if '_' in exit_name: - exit_number = int("0" + exit_name[5:]) + exit_number = int(exit_name[5:]) else: - exit_number = int("0" + exit_name[4:]) - + exit_number = 0 expr = loop.while_loop() - op = _expr.TupleGetItem(expr, exit_number) + body_pos = exit_number + for i, j in enumerate(self._loop_body_order[node_name_prefix]): + if exit_number == j: + body_pos = i + break + op = [_expr.TupleGetItem(expr, body_pos)] elif node.op == "Enter": - op = self._nodes[node.input[0]] + op = self._backtrack_construct(node.input[0]) elif node.op == "LoopCond": - op = self._nodes[node.input[0]] + op = self._backtrack_construct(node.input[0]) assert len(op) == 1 self._loops[node_name_prefix].cond = op[0] elif node.op == "Switch": - op = self._nodes[node.input[0]] + op = self._backtrack_construct(node.input[0]) + cond = self._backtrack_construct(node.input[1]) assert len(op) == 1 - if _in_while_loop(control_flow_node_map, node_name_prefix): + if _in_while_loop(self._control_flow_node_map, node_name_prefix): + if node_name_prefix not in self._loop_var_order: + self._loop_var_order[node_name_prefix] = [] + if node.name.endswith("Switch"): + self._loop_var_order[node_name_prefix].append(0) + else: + self._loop_var_order[node_name_prefix].\ + append(int(node.name.split("Switch_")[-1])) self._loops[node_name_prefix].loop_vars.append(op[0]) else: if node_name_prefix not in self._branches: self._branches[node_name_prefix] = Branch() - chk_op = _infer_type(op[0]) - self._branches[node_name_prefix].cond = chk_op + self._branches[node_name_prefix].cond = cond[0] elif node.op == "NextIteration": - op = self._nodes[node.input[0]] + if node_name_prefix not in self._loop_body_order: + self._loop_body_order[node_name_prefix] = [] + if node.name.endswith("NextIteration"): + self._loop_body_order[node_name_prefix].append(0) + else: + self._loop_body_order[node_name_prefix].\ + append(int(node.name.split("NextIteration_")[-1])) + op = self._backtrack_construct(node.input[0]) + assert len(op) == 1 self._loops[node_name_prefix].body.append(op[0]) else: @@ -2706,7 +2885,6 @@ class GraphProto(object): return op - def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to relay operator. @@ -2741,10 +2919,8 @@ class GraphProto(object): elif op_name in convert_map: if _need_prelude_for_shape_inference(op_name): sym = convert_map[op_name](inputs, attrs, self._params, self._prelude) - elif _need_module_for_shape_inference(op_name): - sym = convert_map[op_name](inputs, attrs, self._params, self._mod) else: - sym = convert_map[op_name](inputs, attrs, self._params) + sym = convert_map[op_name](inputs, attrs, self._params, self._mod) elif op_name in convert_map_rnn: sym = self._convert_rnn_operator(op_name, inputs, attrs, @@ -2754,6 +2930,67 @@ class GraphProto(object): raise NotImplementedError("Operator {} not implemented.".format(op_name)) return sym + def _backtrack_construct(self, node_name): + """Convert a specific tensorflow node to relay expression. 
+
+        If any of its ancestor nodes are not converted yet, backtrack as
+        far as the input node and convert all nodes on the path.
+
+        This is required when parsing control flow nodes, since the parsing
+        order may not follow the original graph def.
+
+        Parameters
+        ----------
+        node_name : str
+            Tensorflow node name.
+
+        Returns
+        -------
+        op : relay.Expr
+            Converted relay expression.
+        """
+        node_name = node_name.split(':')[0].split("^")[-1]
+
+        if node_name not in self._nodes:
+            node = self._tf_node_map[node_name]
+            attr = self._parse_attr(node.attr)
+
+            if node.op in _control_flow_nodes:
+                attr = self._parse_attr(node.attr)
+                op = self._convert_control_flow_operator(node, [],
+                                                         attr,
+                                                         self._control_flow_node_map)
+            else:
+                attr["_output_shapes"] = self._output_shapes[node_name]
+                attr["_node_name"] = node.name
+                attr["_target_layout"] = self._layout
+                inputs = []
+                for iname in node.input:
+                    in_op = self._backtrack_construct(iname)
+                    if isinstance(in_op, _expr.TupleWrapper):
+                        tn = iname.split(':')
+                        tensor_slot = int(tn[1]) if len(tn) > 1 else 0
+                        in_op = in_op[tensor_slot]
+                    else:
+                        in_op = in_op[0]
+
+                    inputs.append(in_op)
+                op = self._convert_operator(node.op, inputs, attr, self._graph)
+
+            if isinstance(op, np.ndarray):
+                self._params[node.name] = tvm.nd.array(op)
+                op = [_expr.var(node.name,
+                                shape=self._params[node.name].shape,
+                                dtype=self._params[node.name].dtype)]
+
+            elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)):
+                op = [op]
+
+            node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0])
+            self._hash2tfnode[node_hash] = node
+            self._nodes[node_name] = op
+
+        return self._nodes[node_name]

 def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     """Load tensorflow graph which is a python tensorflow graph object into relay.
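The new _backtrack_construct converts nodes on demand: it strips the ":slot" and "^control" decorations from an input name, recursively converts any ancestors that have not been visited yet, and memoizes the result in self._nodes (it also records a structural hash per node for the loop-bound machinery). The toy sketch below mirrors the same normalize, recurse, and memoize pattern on a plain dictionary graph; toy_graph and convert_op are hypothetical stand-ins, not frontend APIs.

# Minimal sketch of the backtracking conversion pattern, assuming a toy graph
# of {node_name: list_of_input_names}; convert_op stands in for the real
# per-operator converters.
toy_graph = {
    "const_a": [],
    "const_b": [],
    "add": ["const_a:0", "const_b"],
    "relu": ["add:0", "^const_a"],   # "^name" marks a control dependency
}

converted = {}  # plays the role of self._nodes

def convert_op(name, inputs):
    return (name, tuple(inputs))  # stand-in for a converted relay expression

def backtrack_construct(node_name):
    # Normalize "name:slot" and "^name" the same way the frontend does.
    node_name = node_name.split(':')[0].split('^')[-1]
    if node_name not in converted:
        inputs = [backtrack_construct(i) for i in toy_graph[node_name]]
        converted[node_name] = convert_op(node_name, inputs)
    return converted[node_name]

print(backtrack_construct("relu"))
# ('relu', (('add', (('const_a', ()), ('const_b', ()))), ('const_a', ())))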
diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index 552b150..9777a8d 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -27,14 +27,16 @@ from tvm import relay from tvm.relay.frontend.tensorflow import from_tensorflow -def check_equal(graph, tf_out): +def check_equal(graph, tf_out, input_map=None): mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + if input_map is not None: + params.update(input_map) ex = relay.create_executor('vm', mod=mod) relay_out = ex.evaluate()(**params) if isinstance(relay_out, nd.NDArray): np.testing.assert_allclose(tf_out, relay_out.asnumpy()) else: - if not isinstance(tf_out, list): + if not isinstance(tf_out, (list, tuple)): tf_out = [tf_out] for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]): np.testing.assert_allclose(x, y) @@ -303,9 +305,70 @@ def test_cond_in_loop(): check_equal(graph, tf_out) +def test_vanilla_loop_bound(): + graph = tf.Graph() + with graph.as_default(): + dshape = (2, 10) + dtype = "float32" + dname = "data" + np_data = np.random.uniform(size=dshape).astype(dtype) + data = tf.placeholder(shape=dshape, dtype=dtype, name=dname) + x = tf.slice(data, [1, 4], [1, 4]) + outer = x + 5.0 + def body(x, y): + res = tf.cond(tf.less(y, 10), lambda: tf.add( + 10.0, 20.0), lambda: tf.square(10.0)) + z = tf.constant(7) + res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10) + return tf.multiply(res, x * outer), y + 1 + + y = tf.constant(0) + def condition(x, y): + return tf.less(y, 20) + + r = tf.while_loop(condition, body, loop_vars=[x, y]) + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data}) -if __name__ == "__main__": + check_equal(graph, tf_out, {dname: np_data}) +def test_nested_loop_bound(): + graph = tf.Graph() + with graph.as_default(): + dshape = (2, 10) + dtype = "float32" + dname = "data" + np_data = np.random.uniform(size=dshape).astype(dtype) + data = tf.placeholder(shape=dshape, dtype=dtype, name=dname) + x = tf.slice(data, [1, 4], [1, 4]) + outer = x + 5.0 + def body(x, y): + res = tf.cond(tf.less(y, 10), lambda: tf.add( + 10.0, 20.0), lambda: tf.square(10.0)) + def nested_body(nx, ny): + return nx + 1, res + 2.0 + def nested_cond(nx, ny): + return tf.less(nx, 15) + nx = tf.constant(0) + ny = tf.constant(0.0) + nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny]) + res = res + nested_res[1] + z = tf.constant(7) + res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10) + return tf.multiply(res, x * outer), y + 1 + + y = tf.constant(0) + def condition(x, y): + return tf.less(y, 20) + + r = tf.while_loop(condition, body, loop_vars=[x, y]) + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data}) + + check_equal(graph, tf_out, {dname: np_data}) + + +if __name__ == "__main__": # tf.while_loop test_vanilla_loop() test_loop_2_vars() @@ -325,3 +388,5 @@ if __name__ == "__main__": test_nested_cond() test_loop_in_cond() test_cond_in_loop() + test_vanilla_loop_bound() + test_nested_loop_bound() diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py index 5ce8dab..01ad6a2 100644 --- a/tests/python/frontend/tensorflow/test_debugging.py +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -67,13 +67,11 @@ def test_assert_true_var_capture(): x_value = np.random.rand() assert sess.run(assert_op, 
feed_dict={x: x_value}) is None - # ToDo: The frontend converter gets confused here as well, thinking - # that it needs to be told what x is twice. It also notes the output of + # TODO: The frontend converter notes the output of # the graph as a boolean, which is not correct - as you can see above, - # TF believes that the value of this graph is None. In addition, the - # arity of the translated function should be 1, not 2. + # TF believes that the value of this graph is None. np.testing.assert_allclose(True, - run_relay(g, None, x_value, x_value).asnumpy()) + run_relay(g, None, x_value).asnumpy()) def test_assert_false(): g = tf.Graph() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 78d504e..9d875c1 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1207,6 +1207,8 @@ def test_forward_stridedslice(): '''test StridedSlice''' _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1) + _test_stridedslice((2, 1), [0], [1], [1], 'float32', shrink_axis_mask=1) + _test_stridedslice((2, 3, 4), [0], [1], [1], 'float32', shrink_axis_mask=8) _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32') _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [ -- 2.7.4
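The two StridedSlice cases added to test_forward.py exercise shrink_axis_mask, which treats the corresponding slice spec as a plain index and removes that axis from the result; the second case (shrink_axis_mask=8 with a single slice spec on a rank-3 input) presumably covers a mask bit beyond the supplied specs. A rough numpy analogy for the first new case, assuming only the simplest mask semantics, is sketched below; the exact behavior is defined by tf.strided_slice.

import numpy as np

x = np.random.rand(2, 1).astype("float32")

# tf.strided_slice(x, begin=[0], end=[1], strides=[1], shrink_axis_mask=1)
# treats the single slice spec as an index on axis 0 and drops that axis,
# so the result behaves like plain indexing:
sliced = x[0]
assert sliced.shape == (1,)   # not (1, 1)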