From cf0a7e28a8d20aec0b88bb79cb40c26a60d33723 Mon Sep 17 00:00:00 2001 From: zhengdi Date: Sun, 8 Mar 2020 14:27:56 +0800 Subject: [PATCH] [FRONTEND][TENSORFLOW] support multiply outputs (#4980) * [FRONTEND][TENSORFLOW] support multiply outputs * [TENSORFLOW][TEST] add tf_testing.AddShapesToGraphDef test * update frontend test * retrigger CI --- python/tvm/relay/testing/tf.py | 9 +++++++-- tests/python/frontend/tensorflow/test_forward.py | 6 +----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 1dbbf14..dfe5801 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -66,6 +66,11 @@ def ProcessGraphDefParam(graph_def): return graph_def +def convert_to_list(x): + if not isinstance(x, list): + x = [x] + return x + def AddShapesToGraphDef(session, out_node): """ Add shapes attribute to nodes of the graph. Input graph here is the default graph in context. @@ -74,7 +79,7 @@ def AddShapesToGraphDef(session, out_node): ---------- session : tf.Session Tensorflow session - out_node : String + out_node : String or List Final output node of the graph. Returns @@ -87,7 +92,7 @@ def AddShapesToGraphDef(session, out_node): graph_def = tf_compat_v1.graph_util.convert_variables_to_constants( session, session.graph.as_graph_def(add_shapes=True), - [out_node], + convert_to_list(out_node), ) return graph_def diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 31c5480..3d033d3 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -171,11 +171,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, with tf.Session() as sess: if init_global_variables: sess.run(variables.global_variables_initializer()) - final_graph_def = tf.graph_util.convert_variables_to_constants( - sess, - sess.graph.as_graph_def(add_shapes=True), - out_node, - ) + final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) tf_output = run_tf_graph(sess, in_data, in_name, out_name) for device in ["llvm", "cuda"]: -- 2.7.4