[Relay] Better shape inference in TensorFlow Frontend. (#3176)
authorJosh Fromm <jwfromm@uw.edu>
Fri, 17 May 2019 10:41:50 +0000 (03:41 -0700)
committerSiva <sivar.b@huawei.com>
Fri, 17 May 2019 10:41:50 +0000 (16:11 +0530)
* Some bug fixes in tensorflow graph converter and added DepthToSpace operator.

* Made DepthToSpace better comply with other function syntax.

* Added better shape inference for unusual situations.

* Lint fixes.

* Added depthtospace test.

* Added test cases for value inference and depthtospace.

* Added fill testing.

* Made comment changes and added BroadcastTo op and tests.

* Fixed underlining and unneeded opt_level forcing.

* Added _infer_value assertion that all values to infer are available in passed parameters.

python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index b5a9ea5..11026b9 100644 (file)
@@ -34,6 +34,20 @@ from ..expr_functor import ExprMutator
 
 __all__ = ['from_tensorflow']
 
+def _infer_value(input_val, params):
+    from tvm.contrib import graph_runtime
+    # Check that all free variables have associated parameters.
+    assert all(var.name_hint in params.keys() for var in ir_pass.free_vars(
+        input_val)), "All inputs to infer must be available in params."
+    func = _expr.Function(ir_pass.free_vars(input_val), input_val)
+    with tvm.relay.build_config(opt_level=0):
+        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+    ctx = tvm.context("llvm", 0)
+    m = graph_runtime.create(graph, lib, ctx)
+    m.set_input(**params)
+    m.run()
+    return m.get_output(0)
+
 def _get_relay_op(op_name):
     try:
         op = getattr(_op, op_name)
@@ -465,7 +479,12 @@ def _expand_dims():
 
 def _resize_bilinear():
     def _impl(inputs, attr, params):
-        attr['size'] = attr['_output_shapes'][0][1:3]
+        size = attr['_output_shapes'][0][1:3]
+        # Important that the size is defined. If an axis is not, we need to infer what
+        # the shape should be.
+        if -1 in size:
+            size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
+        attr['size'] = size
         inputs.pop(1)
         # NHWC
         attr['layout'] = 'NHWC'
@@ -574,15 +593,7 @@ def _reshape():
         except AttributeError:
             # Shape operator is already pruned, hence
             # try to infer shape by precompute prune if possible.
-            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
-            with tvm.relay.build_config(opt_level=0):
-                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-            ctx = tvm.context("llvm", 0)
-            from tvm.contrib import graph_runtime
-            m = graph_runtime.create(graph, lib, ctx)
-            m.set_input(**params)
-            m.run()
-            params_new = m.get_output(0)
+            params_new = _infer_value(inputs[1], params)
             inputs.pop(1)
             return AttrCvt(
                 op_name="reshape",
@@ -590,9 +601,63 @@ def _reshape():
                 ignores=['Tshape'])(inputs, attr)
     return _impl
 
+
+def _depth_to_space():
+    def _impl(inputs, attr, params):
+        # Need to handle data layouts differently.
+        input_shape = attr['_input_shapes'][inputs[0]]
+        block_size = int(attr['block_size'])
+        if attr['data_format'].decode("utf-8") == 'NHWC':
+            in_n, in_h, in_w, in_c = input_shape
+            new_c = int(in_c / (block_size * block_size))
+
+            # First expand input to larger dimension.
+            expanded = _op.reshape(
+                inputs[0], newshape=(in_n, in_h, in_w, block_size, block_size, new_c))
+            # Now reorder to expand spatial blocks.
+            transposed = _op.transpose(expanded, axes=(0, 1, 3, 2, 4, 5))
+            # Finally reshape to proper output.
+            new_h = in_h * block_size
+            new_w = in_w * block_size
+            newshape = (in_n, new_h, new_w, new_c)
+
+        else: # Handle NCHW layout
+            in_n, in_c, in_h, in_w = input_shape
+            new_c = int(in_c / (block_size * block_size))
+
+            expanded = _op.reshape(
+                inputs[0], newshape=(in_n, block_size, block_size, new_c, in_h, in_w))
+            transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2))
+            new_h = in_h * block_size
+            new_w = in_w * block_size
+            newshape = (in_n, new_c, new_h, new_w)
+
+        return AttrCvt(
+            op_name="reshape",
+            extras={'newshape': newshape},
+            ignores=['data_format', 'block_size'])([transposed], attr)
+
+    return _impl
+
+
 def _bias_add():
     def _impl(inputs, attr, params):
-        return _op.add(inputs[0], inputs[1])
+        # Must expand for proper broadcasting in NCHW.
+        if attr['data_format'].decode("utf-8") == 'NCHW':
+            bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1))
+        else:
+            bias = inputs[1]
+        return _op.add(inputs[0], bias)
+    return _impl
+
+def _broadcast_to():
+    def _impl(inputs, attr, params):
+        if isinstance(inputs[1], _expr.Var):
+            shape = params[inputs[1].name_hint]
+        else:
+            shape = _infer_value(inputs[1], params)
+        shape = list(shape.asnumpy().reshape([-1]))
+        return _op.broadcast_to(inputs[0], shape)
     return _impl
 
 def _squeeze():
@@ -666,9 +731,15 @@ def _shape():
 
 def _fill():
     def _impl(inputs, attr, params):
+        output_shape = attr['_output_shapes'][0]
+        # Output shape must be defined to avoid errors. If any axis is not, we must
+        # try to compute its shape.
+        if -1 in output_shape:
+            output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()
+
         fill_arg = params.pop(inputs.pop(1).name_hint)
         return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
-                        attr['_output_shapes'][0], attr['T'].name)
+                        output_shape, attr['T'].name)
     return _impl
 
 def _lrn():
@@ -1115,6 +1186,7 @@ _convert_map = {
     'BatchNormWithGlobalNormalization'  : _batch_norm(),
     'BatchToSpaceND'                    : _batch_to_space_nd(),
     'BiasAdd'                           : _bias_add(),
+    'BroadcastTo'                       : _broadcast_to(),
     'Cast'                              : _cast(),
     'Ceil'                              : AttrCvt('ceil'),
     'CheckNumerics'                     : _check_numerics(),
@@ -1123,6 +1195,7 @@ _convert_map = {
     'Conv2D'                            : _conv('conv'),
     'DecodeJpeg'                        : _decode_image(),
     'DepthwiseConv2dNative'             : _conv('depthwise'),
+    'DepthToSpace'                      : _depth_to_space(),
     'Equal'                             : _broadcast('equal'),
     'Elu'                               : _elu(),
     'Exp'                               : AttrCvt('exp'),
@@ -1158,11 +1231,12 @@ _convert_map = {
     'Prod'                              : _prod(),
     'Range'                             : _range(),
     'Rank'                              : _rank(),
-    'RealDiv'                           : _elemwise('div'),
+    'RealDiv'                           : _elemwise('divide'),
     'Relu'                              : AttrCvt('relu'),
     'Relu6'                             : _relu6(),
     'Reshape'                           : _reshape(),
     'ResizeBilinear'                    : _resize_bilinear(),
+    'ResizeBicubic'                     : _resize_bilinear(),
     'ReverseV2'                         : _reverse_v2(),
     'Round'                             : AttrCvt('round'),
     'Rsqrt'                             : _rsqrt(),
index 90ee758..e4626e0 100644 (file)
@@ -47,7 +47,8 @@ def convert_to_list(x):
         x = [x]
     return x
 
-def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None):
+def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
+                  target='llvm', out_names=None, opt_level=3):
     """ Generic function to compile on relay and execute on tvm """
     input_data = convert_to_list(input_data)
     input_node = convert_to_list(input_node)
@@ -71,7 +72,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
                                                  layout=layout,
                                                  shape=shape_dict,
                                                  outputs=out_names)
-    with relay.build_config(opt_level=3):
+    with relay.build_config(opt_level=opt_level):
         graph, lib, params = relay.build(sym, target, params=params)
 
     ctx = tvm.context(target, 0)
@@ -85,8 +86,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
     # execute
     m.run()
     # get outputs
-    assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format(
-                                                              out_names, num_output)
+    assert out_names is None or num_output == len(out_names), (
+        "out_names: {} num_output: {}".format(out_names, num_output))
     tvm_output_list = []
     for i in range(0, num_output):
         tvm_output = m.get_output(i)
@@ -111,7 +112,8 @@ def run_tf_graph(sess, input_data, input_node, output_node):
     return output_data
 
 
-def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
+def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
+                        no_gpu=False, opt_level=3):
     """Generic function to generate and compare tensorflow and TVM output"""
 
     out_name = convert_to_list(out_name)
@@ -142,8 +144,9 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
             if no_gpu and device == 'cuda':
                 continue
 
-            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device,
-                                       out_names=out_name, num_output=len(out_name))
+            tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
+                                       target=device, out_names=out_name,
+                                       num_output=len(out_name), opt_level=opt_level)
             # since the names from tensorflow and relay runs are not exactly same,
             # first len(tf_output) will be compared
             for i in range(len(tf_output)):
@@ -411,6 +414,23 @@ def test_forward_reshape():
     _test_reshape(np.arange(6), [-1])
 
 #######################################################################
+# DepthToSpace
+# ------------
+
+def _test_depthtospace(data, block_size):
+    """ One iteration of depth_to_space operation with given data and block size """
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        array_ops.depth_to_space(in_data, block_size)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0')
+
+def test_forward_depthtospace():
+    _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
+    _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)
+
+
 #######################################################################
 # Squeeze
 # -------
@@ -840,16 +860,108 @@ def _test_resize_bilinear(in_shape, to_shape, align_corners):
 
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
+        shape_data = constant_op.constant(
+            shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
         tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
 
         compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
 
+def _test_resize_bilinear_from_tensor(in_shape, align_corners):
+    """ One iteration of resize bilinear with non-constant output shape, requires
+        value inference to get proper output shape."""
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(
+            shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
+        to_shape = tf.shape(in_data)[2:]
+        tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
+
 def test_forward_resize_bilinear():
     """ Resize Bilinear """
 
     _test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
     _test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
+    _test_resize_bilinear_from_tensor((4, 16, 32, 32), False)
+    _test_resize_bilinear_from_tensor((6, 32, 50, 50), True)
+
+#######################################################################
+# BroadcastTo
+# -----------
+
+def _test_broadcast_to(in_shape, to_shape):
+    """ One iteration of broadcast_to"""
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+    shape_data = np.array(to_shape).astype('int32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        shape_data = constant_op.constant(
+            shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
+        tf.broadcast_to(in_data, shape_data)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0', opt_level=0)
+
+
+def _test_broadcast_to_from_tensor(in_shape):
+    """ One iteration of broadcast_to with unknown shape at graph build"""
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(
+            shape=[None], dtype=data.dtype)
+
+        shape_data = tf.multiply(tf.shape(in_data), 32)
+        tf.broadcast_to(in_data, shape_data)
+
+        compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0')
+
+
+def test_forward_broadcast_to():
+    """ Resize Bilinear """
+
+    _test_broadcast_to((4, 1, 32, 32), [4, 8, 32, 32])
+    _test_broadcast_to((6, 32, 32, 1), [6, 32, 32, 16])
+    _test_broadcast_to_from_tensor((1))
+
+
+#######################################################################
+# Fill
+# ----
+
+def _test_fill(in_shape):
+    """ Use the fill op to create a tensor of ones with non-constant shape."""
+
+    with tf.Graph().as_default():
+        tf.ones(shape=in_shape, dtype='float32')
+        compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1)
+
+def _test_fill_from_tensor(in_shape):
+    """ Use the fill op to create a tensor of ones with non-constant shape.
+        Some extra ops need to be added here to prevent the graph from
+        being fully constant and folded away."""
+
+    data = np.random.uniform(size=in_shape).astype('float32')
+
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(
+            shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
+
+        x = tf.ones(shape=2*tf.shape(in_data), dtype=data.dtype)
+        y = tf.math.add(in_data, tf.reduce_mean(x), name='out1')
+        compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0')
+
+def test_forward_fill():
+    """ Resize Bilinear """
+
+    _test_fill((32))
+    _test_fill((6, 32, 64, 64))
+    _test_fill_from_tensor((6, 32, 64, 64))
 
 #######################################################################
 # Crop to bounding box
@@ -1567,9 +1679,12 @@ if __name__ == '__main__':
     # Transforms
     test_forward_transpose()
     test_forward_reshape()
+    test_forward_depthtospace()
     test_forward_squeeze()
     test_forward_pack()
     test_forward_resize_bilinear()
+    test_forward_broadcast_to()
+    test_forward_fill()
     test_forward_crop()
     test_forward_pad()
     test_forward_gather()