From: Yao Wang Date: Wed, 17 Jun 2020 17:05:35 +0000 (-0700) Subject: [Frontend][TensorFlow]Fix TF Dynamic input shape (#5825) X-Git-Tag: upstream/0.7.0~539 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5e28bcdd9a0ed7e230dc6dee5a7e50a580e1f148;p=platform%2Fupstream%2Ftvm.git [Frontend][TensorFlow]Fix TF Dynamic input shape (#5825) * Fix TF Dynamic input shape * Remove warning * Add test --- diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index af09877..62dadce 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2824,9 +2824,7 @@ class GraphProto(object): tensor_util.TensorShapeProtoToList(node.attr['shape'].shape) for idx, dim in enumerate(self._input_shapes[node.name]): if dim < 0: - self._input_shapes[node.name][idx] = 1 - warnings.warn("Use 1 instead of -1 in shape of operator %s." - % node.name) + self._input_shapes[node.name][idx] = Any() self._output_shapes[node.name] = [self._input_shapes[node.name]] attr = self._parse_attr(node.attr) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 6f3b7f4..1a0baf8 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -100,14 +100,18 @@ def vmobj_to_list(o): def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None, opt_level=3, mode='graph_runtime', - cuda_layout="NCHW", layout=None, disabled_pass=None): + cuda_layout="NCHW", layout=None, disabled_pass=None, ignore_in_shape=False): """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) if target == "cuda": layout = cuda_layout target_host = None - shape_dict = {e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)} + if ignore_in_shape: + shape_dict = None + else: + shape_dict = {e: i.shape if hasattr(i, "shape") else () + for e, i in zip(input_node, input_data)} mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, @@ -3715,6 +3719,33 @@ def test_forward_spop(): _test_spop_variables() _test_spop_constants() +####################################################################### +# Dynamic input shape +# ------------------- +def test_forward_dynamic_input_shape(): + tf.reset_default_graph() + + with tf.Graph().as_default(): + data = tf.placeholder(tf.float32, name='data', shape=(None,)) + out = data + 1 + np_data = np.random.uniform(size=(2,)).astype("float32") + out_name = "add" + + with tf.Session() as sess: + graph_def = tf_testing.AddShapesToGraphDef(sess, out_name) + tf_output = run_tf_graph(sess, np_data, 'data:0', ['{}:0'.format(out_name)]) + # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready. + for device in ["llvm"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + tvm_output = run_tvm_graph(graph_def, np_data, ["data"], 1, + target=device, layout="NCHW", out_names=[out_name], + mode="vm", ignore_in_shape=True) + tvm.testing.assert_allclose(tvm_output[0], tf_output[0], + rtol=1e-5, atol=1e-5) + ####################################################################### # Main @@ -3851,3 +3882,6 @@ if __name__ == '__main__': # StatefulPartitionedCall test_forward_spop() + + # Test dynamic input shape + test_forward_dynamic_input_shape()