From c4439a8046fe61fd4603403d346db47c7fc8886d Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 16 May 2019 09:25:38 +0530 Subject: [PATCH] [TENSORLFOW] PlaceholderWithDefault (limited) implementation. (#3184) --- python/tvm/relay/frontend/tensorflow.py | 6 +++--- tests/python/frontend/tensorflow/test_forward.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4bd78b4..b5a9ea5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1740,7 +1740,7 @@ class GraphProto(object): for node in graph.node: node_name_prefix = node.name.rsplit('/', 1)[0] control_flow_node_map[node_name_prefix].add(node.op) - if node.op == 'Placeholder': + if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault': # Give priority to user argument. if shape and node.name in shape: self._input_shapes[node.name] = list(shape[node.name]) @@ -1800,7 +1800,7 @@ class GraphProto(object): attr = self._parse_attr(node.attr) - elif node.op != "Placeholder": + elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault': # Pass the parsed shapes instead attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] @@ -1925,7 +1925,7 @@ class GraphProto(object): """ missing_operators = set() for node in graph.node: - if node.op == "Placeholder": + if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault': pass elif node.op == "Const": pass diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 2f1cc2f..90ee758 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1541,6 +1541,24 @@ def test_forward_reduce_prod(): _test_forward_reduce_prod((5, 5), 0, True) _test_forward_reduce_prod((5, 5), 1, True) + +####################################################################### +# PlaceholderWithDefault +# ---------------------- +def test_placeholder(): + with tf.Graph().as_default(): + in_data1 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32) + var1 = tf.Variable(in_data1, name='in1') + var2 = array_ops.placeholder_with_default(var1, None, name='place1') + + in_data2 = np.random.uniform(-5, 5, size=(3, 4, 5)).astype(np.float32) + place1 = array_ops.placeholder(shape=in_data1.shape, dtype=in_data1.dtype, name='in2') + + out1 = tf.math.add(var1, var2, name='out1') + out2 = tf.math.add(out1, place1, name='out2') + + compare_tf_with_tvm([in_data1, in_data2], ['place1:0', 'in2:0'], 'out2:0', init_global_variables=True) + ####################################################################### # Main # ---- @@ -1590,6 +1608,7 @@ if __name__ == '__main__': test_forward_multi_input() test_forward_multi_output() test_forward_variable() + test_placeholder() # NN test_forward_convolution() -- 2.7.4