[Frontend][TensorFlow] Improve TensorFlow control flow nodes ordering (#6387)
authorYao Wang <kevinthesunwy@gmail.com>
Thu, 3 Sep 2020 16:25:14 +0000 (09:25 -0700)
committerGitHub <noreply@github.com>
Thu, 3 Sep 2020 16:25:14 +0000 (09:25 -0700)
* Improve TensorFlow control flow nodes ordering

* Fix Lint

python/tvm/relay/frontend/tensorflow.py

index 24f1b8b..02c8204 100644 (file)
@@ -2716,26 +2716,27 @@ class GraphProto(object):
         # First, parse all control flow nodes.
         # Convert tf.cond to Branch and tf.while_loop to Loop.
         sorted_cf_nodes = []
-        current_node_name_prefix = None
-        exits = []
+        exit_pos_map = {}
+        ordered_prefix = []
         # Sort control flow nodes to move all Exit nodes to the end
         # of corresponding while_loop block.
-        for i, node in enumerate(control_flow_nodes):
-            node_name_prefix = node.name.rsplit('/', 1)[0]
-            if current_node_name_prefix is None or current_node_name_prefix != node_name_prefix:
-                if node_name_prefix in self._while_loop_name_set:
-                    sorted_cf_nodes.extend(exits)
-                    exits.clear()
-                    current_node_name_prefix = node_name_prefix
-
+        for node in control_flow_nodes:
+            loop_name = find_parent_loop_name(node.name, self._while_loop_name_set)
             if node.op == "Exit":
-                exits.append(node)
+                if loop_name not in exit_pos_map:
+                    ordered_prefix.append(loop_name)
+                    exit_pos_map[loop_name] = len(sorted_cf_nodes)
+                sorted_cf_nodes.append(node)
+            elif loop_name in self._while_loop_name_set:
+                if loop_name not in exit_pos_map:
+                    sorted_cf_nodes.append(node)
+                else:
+                    sorted_cf_nodes.insert(exit_pos_map[loop_name], node)
+                    for j in range(ordered_prefix.index(loop_name), len(ordered_prefix)):
+                        exit_pos_map[ordered_prefix[j]] += 1
             else:
                 sorted_cf_nodes.append(node)
 
-            if i == len(control_flow_nodes) - 1:
-                sorted_cf_nodes.extend(exits)
-
         for node in sorted_cf_nodes:
             self._sorted_cf_node_names.append(node.name)