branch = self._branches[node_name_prefix]
false_br = self._backtrack_construct(node.input[0])
true_br = self._backtrack_construct(node.input[1])
- assert len(true_br) == 1
- assert len(false_br) == 1
- branch.true_branch = true_br[0]
- branch.false_branch = false_br[0]
- op = [branch.if_node()]
+ branch.true_branch = true_br
+ branch.false_branch = false_br
+ op = branch.if_node()
if node_name_prefix not in self._while_loop_name_set:
try:
cond_val = np.all(_infer_value(branch.cond, self._params,
self._mod).asnumpy())
if cond_val:
- op = [branch.true_branch]
+ op = branch.true_branch
else:
- op = [branch.false_branch]
+ op = branch.false_branch
except Exception:
- op = [branch.if_node()]
+ op = branch.if_node()
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
if exit_number == j:
body_pos = i
break
- op = [_expr.TupleGetItem(expr, body_pos)]
+ op = _expr.TupleGetItem(expr, body_pos)
elif node.op == "Enter":
op = self._backtrack_construct(node.input[0])
elif node.op == "LoopCond":
op = self._backtrack_construct(node.input[0])
- assert len(op) == 1
- self._loops[node_name_prefix].cond = op[0]
+ self._loops[node_name_prefix].cond = op
elif node.op == "Switch":
op = self._backtrack_construct(node.input[0])
cond = self._backtrack_construct(node.input[1])
- assert len(op) == 1
if _in_while_loop(self._control_flow_node_map, node_name_prefix):
if node_name_prefix not in self._loop_var_order:
self._loop_var_order[node_name_prefix] = []
else:
self._loop_var_order[node_name_prefix].\
append(int(node.name.split("Switch_")[-1]))
- self._loops[node_name_prefix].loop_vars.append(op[0])
+ self._loops[node_name_prefix].loop_vars.append(op)
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
- self._branches[node_name_prefix].cond = cond[0]
+ self._branches[node_name_prefix].cond = cond
elif node.op == "NextIteration":
if node_name_prefix not in self._loop_body_order:
self._loop_body_order[node_name_prefix] = []
self._loop_body_order[node_name_prefix].\
append(int(node.name.split("NextIteration_")[-1]))
op = self._backtrack_construct(node.input[0])
-
- assert len(op) == 1
- self._loops[node_name_prefix].body.append(op[0])
+ self._loops[node_name_prefix].body.append(op)
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))
op : relay.Expr
Converted relay expression
"""
- node_name = node_name.split(':')[0].split("^")[-1]
+ input_op_name = node_name.split(':')[0].split("^")[-1]
- if node_name not in self._nodes:
- node = self._tf_node_map[node_name]
+ if input_op_name not in self._nodes:
+ node = self._tf_node_map[input_op_name]
attr = self._parse_attr(node.attr)
if node.op in _control_flow_nodes:
attr,
self._control_flow_node_map)
else:
- attr["_output_shapes"] = self._output_shapes[node_name]
+ attr["_output_shapes"] = self._output_shapes[input_op_name]
attr["_node_name"] = node.name
attr["_target_layout"] = self._layout
- inputs = []
- for iname in node.input:
- in_op = self._backtrack_construct(iname)
- if isinstance(in_op, _expr.TupleWrapper):
- tn = iname.split(':')
- tensor_slot = int(tn[1]) if len(tn) > 1 else 0
- in_op = in_op[tensor_slot]
- else:
- in_op = in_op[0]
-
- inputs.append(in_op)
+ inputs = [self._backtrack_construct(iname) for iname in node.input]
op = self._convert_operator(node.op, inputs, attr, self._graph)
if isinstance(op, np.ndarray):
node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0])
self._hash2tfnode[node_hash] = node
- self._nodes[node_name] = op
+ self._nodes[input_op_name] = op
+
+ out = self._nodes[input_op_name]
+
+ if isinstance(out, _expr.TupleWrapper):
+ tn = node_name.split(':')
+ tensor_slot = int(tn[1]) if len(tn) > 1 else 0
+ return out[tensor_slot]
- return self._nodes[node_name]
+ return out[0]
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into relay.
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
+from tensorflow.python.ops import control_flow_ops
import numpy as np
from tvm import nd
from tvm import relay
check_equal(graph, tf_out, {dname: np_data})
+def test_switch():
+ graph = tf.Graph()
+
+ with graph.as_default():
+ data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
+ dname = 'data'
+ flag_name = 'flag'
+ data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
+ split = tf.split(data, 2, axis=0)
+ flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name)
+ output_false, output_true = control_flow_ops.switch(split[1], flag)
+ with tf.Session() as sess:
+ tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False})
+
+ check_equal(graph, tf_out, {dname: data_np, flag_name: False})
+
+
if __name__ == "__main__":
# tf.while_loop
test_vanilla_loop()
test_cond_in_loop()
test_vanilla_loop_bound()
test_nested_loop_bound()
+
+ test_switch()