From 246b41092981ccc4675cf8c97636745e9e59e370 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 17 May 2019 03:41:50 -0700
Subject: [PATCH] [Relay] Better shape inference in TensorFlow Frontend. (#3176)

* Some bug fixes in tensorflow graph converter and added DepthToSpace operator.

* Made DepthToSpace better comply with other function syntax.

* Added better shape inference for unusual situations.

* Lint fixes.

* Added depthtospace test.

* Added test cases for value inference and depthtospace.

* Added fill testing.

* Made comment changes and added BroadcastTo op and tests.

* Fixed underlining and unneeded opt_level forcing.

* Added _infer_value assertion that all values to infer are available in passed parameters.
---
 python/tvm/relay/frontend/tensorflow.py          | 100 ++++++++++++++---
 tests/python/frontend/tensorflow/test_forward.py | 131 +++++++++++++++++++++--
 2 files changed, 210 insertions(+), 21 deletions(-)
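
The heart of this change is the new `_infer_value` helper: when an input that
determines a shape is not a static attribute, the frontend wraps the constant
subexpression that computes it into a Relay function, JIT-compiles it at
opt_level=0, and executes it on CPU to recover a concrete value. A minimal
standalone sketch of that pattern, assuming the TVM 0.6-era APIs this patch
targets (relay.ir_pass, relay.build_config, graph_runtime); the names below
are illustrative, not part of the patch:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    # A shape-producing expression whose inputs are all known: shape * 2.
    shape_var = relay.var("shape", shape=(4,), dtype="int32")
    expr = relay.multiply(shape_var, relay.const(2, "int32"))
    params = {"shape": tvm.nd.array(np.array([1, 3, 224, 224], dtype="int32"))}

    # Wrap the free variables into a function, build without optimization,
    # and run on CPU to materialize the value.
    func = relay.Function(relay.ir_pass.free_vars(expr), expr)
    with relay.build_config(opt_level=0):
        graph, lib, params = relay.build(func, target="llvm", params=params)
    m = graph_runtime.create(graph, lib, tvm.context("llvm", 0))
    m.set_input(**params)
    m.run()
    print(m.get_output(0).asnumpy())  # -> [  2   6 448 448]
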
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index b5a9ea5..11026b9 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -34,6 +34,20 @@ from ..expr_functor import ExprMutator

 __all__ = ['from_tensorflow']

+def _infer_value(input_val, params):
+    from tvm.contrib import graph_runtime
+    # Check that all free variables have associated parameters.
+    assert all(var.name_hint in params.keys() for var in ir_pass.free_vars(
+        input_val)), "All inputs to infer must be available in params."
+    func = _expr.Function(ir_pass.free_vars(input_val), input_val)
+    with tvm.relay.build_config(opt_level=0):
+        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+    ctx = tvm.context("llvm", 0)
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(**params)
+    m.run()
+    return m.get_output(0)
+
 def _get_relay_op(op_name):
     try:
         op = getattr(_op, op_name)
@@ -465,7 +479,12 @@ def _expand_dims():

 def _resize_bilinear():
     def _impl(inputs, attr, params):
-        attr['size'] = attr['_output_shapes'][0][1:3]
+        size = attr['_output_shapes'][0][1:3]
+        # Important that the size is defined. If an axis is not, we need to infer what
+        # the shape should be.
+        if -1 in size:
+            size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
+        attr['size'] = size
         inputs.pop(1)
         # NHWC
         attr['layout'] = 'NHWC'
@@ -574,15 +593,7 @@ def _reshape():
         except AttributeError:
             # Shape operator is already pruned, hence
             # try to infer shape by precompute prune if possible.
-            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
-            with tvm.relay.build_config(opt_level=0):
-                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-            ctx = tvm.context("llvm", 0)
-            from tvm.contrib import graph_runtime
-            m = graph_runtime.create(graph, lib, ctx)
-            m.set_input(**params)
-            m.run()
-            params_new = m.get_output(0)
+            params_new = _infer_value(inputs[1], params)
             inputs.pop(1)
             return AttrCvt(
                 op_name="reshape",
@@ -590,9 +601,63 @@ def _reshape():
                 ignores=['Tshape'])(inputs, attr)
     return _impl

+
+def _depth_to_space():
+    def _impl(inputs, attr, params):
+        # Need to handle data layouts differently.
+        input_shape = attr['_input_shapes'][inputs[0]]
+        block_size = int(attr['block_size'])
+        if attr['data_format'].decode("utf-8") == 'NHWC':
+            in_n, in_h, in_w, in_c = input_shape
+            new_c = int(in_c / (block_size * block_size))
+
+            # First expand input to larger dimension.
+            expanded = _op.reshape(
+                inputs[0], newshape=(in_n, in_h, in_w, block_size, block_size, new_c))
+            # Now reorder to expand spatial blocks.
+            transposed = _op.transpose(expanded, axes=(0, 1, 3, 2, 4, 5))
+            # Finally reshape to proper output.
+            new_h = in_h * block_size
+            new_w = in_w * block_size
+            newshape = (in_n, new_h, new_w, new_c)
+
+        else: # Handle NCHW layout
+            in_n, in_c, in_h, in_w = input_shape
+            new_c = int(in_c / (block_size * block_size))
+
+            expanded = _op.reshape(
+                inputs[0], newshape=(in_n, block_size, block_size, new_c, in_h, in_w))
+            transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2))
+            new_h = in_h * block_size
+            new_w = in_w * block_size
+            newshape = (in_n, new_c, new_h, new_w)
+
+        return AttrCvt(
+            op_name="reshape",
+            extras={'newshape': newshape},
+            ignores=['data_format', 'block_size'])([transposed], attr)
+
+    return _impl
+
+
 def _bias_add():
     def _impl(inputs, attr, params):
-        return _op.add(inputs[0], inputs[1])
+        # Must expand for proper broadcasting in NCHW.
+        if attr['data_format'].decode("utf-8") == 'NCHW':
+            bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1))
+        else:
+            bias = inputs[1]
+        return _op.add(inputs[0], bias)
+    return _impl
+
+def _broadcast_to():
+    def _impl(inputs, attr, params):
+        if isinstance(inputs[1], _expr.Var):
+            shape = params[inputs[1].name_hint]
+        else:
+            shape = _infer_value(inputs[1], params)
+        shape = list(shape.asnumpy().reshape([-1]))
+        return _op.broadcast_to(inputs[0], shape)
     return _impl

 def _squeeze():
@@ -666,9 +731,15 @@ def _shape():

 def _fill():
     def _impl(inputs, attr, params):
+        output_shape = attr['_output_shapes'][0]
+        # Output shape must be defined to avoid errors. If any axis is not, we must
+        # try to compute its shape.
+        if -1 in output_shape:
+            output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()
+
         fill_arg = params.pop(inputs.pop(1).name_hint)
         return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
-                        attr['_output_shapes'][0], attr['T'].name)
+                        output_shape, attr['T'].name)
     return _impl

 def _lrn():
@@ -1115,6 +1186,7 @@ _convert_map = {
     'BatchNormWithGlobalNormalization' : _batch_norm(),
     'BatchToSpaceND' : _batch_to_space_nd(),
     'BiasAdd' : _bias_add(),
+    'BroadcastTo' : _broadcast_to(),
     'Cast' : _cast(),
     'Ceil' : AttrCvt('ceil'),
     'CheckNumerics' : _check_numerics(),
@@ -1123,6 +1195,7 @@ _convert_map = {
     'Conv2D' : _conv('conv'),
     'DecodeJpeg' : _decode_image(),
     'DepthwiseConv2dNative' : _conv('depthwise'),
+    'DepthToSpace' : _depth_to_space(),
     'Equal' : _broadcast('equal'),
     'Elu' : _elu(),
     'Exp' : AttrCvt('exp'),
@@ -1158,11 +1231,12 @@ _convert_map = {
     'Prod' : _prod(),
     'Range' : _range(),
     'Rank' : _rank(),
-    'RealDiv' : _elemwise('div'),
+    'RealDiv' : _elemwise('divide'),
     'Relu' : AttrCvt('relu'),
     'Relu6' : _relu6(),
     'Reshape' : _reshape(),
     'ResizeBilinear' : _resize_bilinear(),
+    'ResizeBicubic' : _resize_bilinear(),
     'ReverseV2' : _reverse_v2(),
     'Round' : AttrCvt('round'),
     'Rsqrt' : _rsqrt(),
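
The DepthToSpace conversion above uses the standard reshape/transpose/reshape
decomposition rather than a dedicated operator. A quick numpy sketch
(illustrative only) of the NHWC case, which matches TensorFlow's documented
behavior on a tiny input and uses the same axis order (0, 1, 3, 2, 4, 5) as
the converter:

    import numpy as np

    def depth_to_space_nhwc(x, bs):
        # Split channels into two block axes, interleave them with the
        # spatial axes, then merge into the enlarged spatial dimensions.
        n, h, w, c = x.shape
        new_c = c // (bs * bs)
        expanded = x.reshape(n, h, w, bs, bs, new_c)
        transposed = expanded.transpose(0, 1, 3, 2, 4, 5)
        return transposed.reshape(n, h * bs, w * bs, new_c)

    x = np.arange(4, dtype='float32').reshape(1, 1, 1, 4)
    print(depth_to_space_nhwc(x, 2)[0, :, :, 0])
    # -> [[0. 1.]
    #     [2. 3.]]
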
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 90ee758..e4626e0 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -47,7 +47,8 @@ def convert_to_list(x):
         x = [x]
     return x

-def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None):
+def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
+                  target='llvm', out_names=None, opt_level=3):
     """ Generic function to compile on relay and execute on tvm """
     input_data = convert_to_list(input_data)
     input_node = convert_to_list(input_node)
@@ -71,7 +72,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
                                              layout=layout,
                                              shape=shape_dict,
                                              outputs=out_names)
-    with relay.build_config(opt_level=3):
+    with relay.build_config(opt_level=opt_level):
         graph, lib, params = relay.build(sym, target, params=params)

     ctx = tvm.context(target, 0)
@@ -85,8 +86,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
     # execute
     m.run()
     # get outputs
-    assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format(
-                                                             out_names, num_output)
+    assert out_names is None or num_output == len(out_names), (
+        "out_names: {} num_output: {}".format(out_names, num_output))
     tvm_output_list = []
     for i in range(0, num_output):
         tvm_output = m.get_output(i)
@@ -111,7 +112,8 @@ def run_tf_graph(sess, input_data, input_node, output_node):

     return output_data

-def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
+def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
+                        no_gpu=False, opt_level=3):
     """Generic function to generate and compare tensorflow and TVM output"""
     out_name = convert_to_list(out_name)

@@ -142,8 +144,9 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
             if no_gpu and device == 'cuda':
                 continue

-            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device,
-                                       out_names=out_name, num_output=len(out_name))
+            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
+                                       target=device, out_names=out_name,
+                                       num_output=len(out_name), opt_level=opt_level)
             # since the names from tensorflow and relay runs are not exactly same,
             # first len(tf_output) will be compared
             for i in range(len(tf_output)):
@@ -411,6 +414,23 @@ def test_forward_reshape():
     _test_reshape(np.arange(6), [-1])

 #######################################################################
+# DepthToSpace
+# ------------
+
+def _test_depthtospace(data, block_size):
+    """ One iteration of depth_to_space operation with given data and block size """
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        array_ops.depth_to_space(in_data, block_size)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0')
+
+def test_forward_depthtospace():
+    _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
+    _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)
+
+#######################################################################
 # Squeeze
 # -------

@@ -840,16 +860,108 @@ def _test_resize_bilinear(in_shape, to_shape, align_corners):

     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
+        shape_data = constant_op.constant(
+            shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
         tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)

         compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')

+def _test_resize_bilinear_from_tensor(in_shape, align_corners):
+    """ One iteration of resize bilinear with non-constant output shape; this
+        requires value inference to get the proper output shape."""
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(
+            shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
+        to_shape = tf.shape(in_data)[2:]
+        tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
+
 def test_forward_resize_bilinear():
     """ Resize Bilinear """

     _test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
     _test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
+    _test_resize_bilinear_from_tensor((4, 16, 32, 32), False)
+    _test_resize_bilinear_from_tensor((6, 32, 50, 50), True)
+
+#######################################################################
+# BroadcastTo
+# -----------
+
+def _test_broadcast_to(in_shape, to_shape):
+    """ One iteration of broadcast_to """
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+    shape_data = np.array(to_shape).astype('int32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        shape_data = constant_op.constant(
+            shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
+        tf.broadcast_to(in_data, shape_data)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0', opt_level=0)
+
+
+def _test_broadcast_to_from_tensor(in_shape):
+    """ One iteration of broadcast_to with unknown shape at graph build"""
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(
+            shape=[None], dtype=data.dtype)
+
+        shape_data = tf.multiply(tf.shape(in_data), 32)
+        tf.broadcast_to(in_data, shape_data)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0')
+
+
+def test_forward_broadcast_to():
+    """ BroadcastTo """
+
+    _test_broadcast_to((4, 1, 32, 32), [4, 8, 32, 32])
+    _test_broadcast_to((6, 32, 32, 1), [6, 32, 32, 16])
+    _test_broadcast_to_from_tensor((1))
+
+
+#######################################################################
+# Fill
+# ----
+
+def _test_fill(in_shape):
+    """ Use the fill op to create a tensor of ones with constant shape."""
+
+    with tf.Graph().as_default():
+        tf.ones(shape=in_shape, dtype='float32')
+        compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1)
+
+def _test_fill_from_tensor(in_shape):
+    """ Use the fill op to create a tensor of ones with non-constant shape.
+        Some extra ops need to be added here to prevent the graph from
+        being fully constant and folded away."""
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(
+            shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
+
+        x = tf.ones(shape=2*tf.shape(in_data), dtype=data.dtype)
+        y = tf.math.add(in_data, tf.reduce_mean(x), name='out1')
+        compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0')
+
+def test_forward_fill():
+    """ Fill """
+
+    _test_fill((32))
+    _test_fill((6, 32, 64, 64))
+    _test_fill_from_tensor((6, 32, 64, 64))

 #######################################################################
 # Crop to bounding box
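
The fill tests above exercise the case where the dims input of Fill is itself
computed, so the node's recorded '_output_shapes' contains -1 and the
converter must fall back on _infer_value. A small TF 1.x sketch (illustrative;
the node listing is indicative, not exact) of how such a graph looks:

    import tensorflow as tf  # assumes the TF 1.x API used throughout this file

    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, shape=[None, 4])
        tf.ones(shape=2 * tf.shape(x), dtype=tf.float32)

    # The Fill node created by tf.ones takes the computed tensor
    # 2 * Shape(x) as its dims, so its static output shape is unknown
    # until that tensor is evaluated.
    print([n.op for n in g.as_graph_def().node])
    # e.g. ['Placeholder', 'Shape', 'Const', 'Mul', 'Const', 'Fill']
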
@@ -1567,9 +1679,12 @@ if __name__ == '__main__':
     # Transforms
     test_forward_transpose()
     test_forward_reshape()
+    test_forward_depthtospace()
     test_forward_squeeze()
     test_forward_pack()
     test_forward_resize_bilinear()
+    test_forward_broadcast_to()
+    test_forward_fill()
     test_forward_crop()
     test_forward_pad()
     test_forward_gather()
-- 
2.7.4
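
A closing note on the _bias_add change: BiasAdd in TensorFlow carries a 1-D,
per-channel bias, and plain trailing-axis broadcasting misaligns it for NCHW
layouts. A numpy illustration of the reshape the converter now emits (the
variable names are illustrative):

    import numpy as np

    x = np.zeros((2, 8, 4, 4), dtype='float32')  # NCHW activations
    b = np.arange(8, dtype='float32')            # per-channel bias, shape (8,)

    # Naive x + b would align b with the trailing width axis (and fails
    # here, since 8 != 4). Reshaping to (1, C, 1, 1) pins the bias to the
    # channel axis instead.
    out = x + b.reshape(1, -1, 1, 1)
    assert out.shape == (2, 8, 4, 4)
    assert np.allclose(out[:, :, 0, 0], b)  # b broadcast across N, H, W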