[TF] Support TupleWrapper as direct ancestor of control flow ops (#5639)
authorlixiaoquan <radioheads@163.com>
Tue, 26 May 2020 18:29:29 +0000 (02:29 +0800)
committerGitHub <noreply@github.com>
Tue, 26 May 2020 18:29:29 +0000 (11:29 -0700)
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_control_flow.py

index ab9e9e6..d4b73f9 100644 (file)
@@ -3073,21 +3073,19 @@ class GraphProto(object):
                 branch = self._branches[node_name_prefix]
                 false_br = self._backtrack_construct(node.input[0])
                 true_br = self._backtrack_construct(node.input[1])
-                assert len(true_br) == 1
-                assert len(false_br) == 1
-                branch.true_branch = true_br[0]
-                branch.false_branch = false_br[0]
-                op = [branch.if_node()]
+                branch.true_branch = true_br
+                branch.false_branch = false_br
+                op = branch.if_node()
                 if node_name_prefix not in self._while_loop_name_set:
                     try:
                         cond_val = np.all(_infer_value(branch.cond, self._params,
                                                        self._mod).asnumpy())
                         if cond_val:
-                            op = [branch.true_branch]
+                            op = branch.true_branch
                         else:
-                            op = [branch.false_branch]
+                            op = branch.false_branch
                     except Exception:
-                        op = [branch.if_node()]
+                        op = branch.if_node()
         elif node.op == "Exit":
             loop = self._loops[node_name_prefix]
 
@@ -3113,17 +3111,15 @@ class GraphProto(object):
                 if exit_number == j:
                     body_pos = i
                     break
-            op = [_expr.TupleGetItem(expr, body_pos)]
+            op = _expr.TupleGetItem(expr, body_pos)
         elif node.op == "Enter":
             op = self._backtrack_construct(node.input[0])
         elif node.op == "LoopCond":
             op = self._backtrack_construct(node.input[0])
-            assert len(op) == 1
-            self._loops[node_name_prefix].cond = op[0]
+            self._loops[node_name_prefix].cond = op
         elif node.op == "Switch":
             op = self._backtrack_construct(node.input[0])
             cond = self._backtrack_construct(node.input[1])
-            assert len(op) == 1
             if _in_while_loop(self._control_flow_node_map, node_name_prefix):
                 if node_name_prefix not in self._loop_var_order:
                     self._loop_var_order[node_name_prefix] = []
@@ -3132,11 +3128,11 @@ class GraphProto(object):
                 else:
                     self._loop_var_order[node_name_prefix].\
                         append(int(node.name.split("Switch_")[-1]))
-                self._loops[node_name_prefix].loop_vars.append(op[0])
+                self._loops[node_name_prefix].loop_vars.append(op)
             else:
                 if node_name_prefix not in self._branches:
                     self._branches[node_name_prefix] = Branch()
-                self._branches[node_name_prefix].cond = cond[0]
+                self._branches[node_name_prefix].cond = cond
         elif node.op == "NextIteration":
             if node_name_prefix not in self._loop_body_order:
                 self._loop_body_order[node_name_prefix] = []
@@ -3146,9 +3142,7 @@ class GraphProto(object):
                 self._loop_body_order[node_name_prefix].\
                     append(int(node.name.split("NextIteration_")[-1]))
             op = self._backtrack_construct(node.input[0])
-
-            assert len(op) == 1
-            self._loops[node_name_prefix].body.append(op[0])
+            self._loops[node_name_prefix].body.append(op)
         else:
             raise Exception("Cannot identify control flow operator: " +
                             "{}".format(node.op))
@@ -3219,10 +3213,10 @@ class GraphProto(object):
         op : relay.Expr
             Converted relay expression
         """
-        node_name = node_name.split(':')[0].split("^")[-1]
+        input_op_name = node_name.split(':')[0].split("^")[-1]
 
-        if node_name not in self._nodes:
-            node = self._tf_node_map[node_name]
+        if input_op_name not in self._nodes:
+            node = self._tf_node_map[input_op_name]
             attr = self._parse_attr(node.attr)
 
             if node.op in _control_flow_nodes:
@@ -3231,20 +3225,10 @@ class GraphProto(object):
                                                          attr,
                                                          self._control_flow_node_map)
             else:
-                attr["_output_shapes"] = self._output_shapes[node_name]
+                attr["_output_shapes"] = self._output_shapes[input_op_name]
                 attr["_node_name"] = node.name
                 attr["_target_layout"] = self._layout
-                inputs = []
-                for iname in node.input:
-                    in_op = self._backtrack_construct(iname)
-                    if isinstance(in_op, _expr.TupleWrapper):
-                        tn = iname.split(':')
-                        tensor_slot = int(tn[1]) if len(tn) > 1 else 0
-                        in_op = in_op[tensor_slot]
-                    else:
-                        in_op = in_op[0]
-
-                    inputs.append(in_op)
+                inputs = [self._backtrack_construct(iname) for iname in node.input]
                 op = self._convert_operator(node.op, inputs, attr, self._graph)
 
             if isinstance(op, np.ndarray):
@@ -3258,9 +3242,16 @@ class GraphProto(object):
 
             node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0])
             self._hash2tfnode[node_hash] = node
-            self._nodes[node_name] = op
+            self._nodes[input_op_name] = op
+
+        out = self._nodes[input_op_name]
+
+        if isinstance(out, _expr.TupleWrapper):
+            tn = node_name.split(':')
+            tensor_slot = int(tn[1]) if len(tn) > 1 else 0
+            return out[tensor_slot]
 
-        return self._nodes[node_name]
+        return out[0]
 
 def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     """Load tensorflow graph which is a python tensorflow graph object into relay.
index 9777a8d..9003527 100644 (file)
@@ -21,6 +21,7 @@ try:
     tf.disable_v2_behavior()
 except ImportError:
     import tensorflow as tf
+from tensorflow.python.ops import control_flow_ops
 import numpy as np
 from tvm import nd
 from tvm import relay
@@ -368,6 +369,23 @@ def test_nested_loop_bound():
     check_equal(graph, tf_out, {dname: np_data})
 
 
+def test_switch():
+    graph = tf.Graph()
+
+    with graph.as_default():
+        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
+        dname = 'data'
+        flag_name = 'flag'
+        data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
+        split = tf.split(data, 2, axis=0)
+        flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name)
+        output_false, output_true = control_flow_ops.switch(split[1], flag)
+        with tf.Session() as sess:
+            tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False})
+
+    check_equal(graph, tf_out, {dname: data_np, flag_name: False})
+
+
 if __name__ == "__main__":
     # tf.while_loop
     test_vanilla_loop()
@@ -390,3 +408,5 @@ if __name__ == "__main__":
     test_cond_in_loop()
     test_vanilla_loop_bound()
     test_nested_loop_bound()
+
+    test_switch()