# First, parse all control flow nodes.
# Convert tf.cond to Branch and tf.while_loop to Loop.
sorted_cf_nodes = []
- current_node_name_prefix = None
- exits = []
+ exit_pos_map = {}
+ ordered_prefix = []
# Sort control flow nodes to move all Exit nodes to the end
# of corresponding while_loop block.
- for i, node in enumerate(control_flow_nodes):
- node_name_prefix = node.name.rsplit('/', 1)[0]
- if current_node_name_prefix is None or current_node_name_prefix != node_name_prefix:
- if node_name_prefix in self._while_loop_name_set:
- sorted_cf_nodes.extend(exits)
- exits.clear()
- current_node_name_prefix = node_name_prefix
-
+ for node in control_flow_nodes:
+ loop_name = find_parent_loop_name(node.name, self._while_loop_name_set)
if node.op == "Exit":
- exits.append(node)
+ if loop_name not in exit_pos_map:
+ ordered_prefix.append(loop_name)
+ exit_pos_map[loop_name] = len(sorted_cf_nodes)
+ sorted_cf_nodes.append(node)
+ elif loop_name in self._while_loop_name_set:
+ if loop_name not in exit_pos_map:
+ sorted_cf_nodes.append(node)
+ else:
+ sorted_cf_nodes.insert(exit_pos_map[loop_name], node)
+ for j in range(ordered_prefix.index(loop_name), len(ordered_prefix)):
+ exit_pos_map[ordered_prefix[j]] += 1
else:
sorted_cf_nodes.append(node)
- if i == len(control_flow_nodes) - 1:
- sorted_cf_nodes.extend(exits)
-
for node in sorted_cf_nodes:
self._sorted_cf_node_names.append(node.name)