self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
+ self._ignores.append('_input_0d_mismatch')
# apply custom check
if self._custom_check:
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False
- input_shape = attr['_input_shapes'][inputs[0]][0]
+ input_shape = attr['_input_shapes'][inputs[0]]
if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
- tmp_shape = attr['_input_shapes'][inputs[0]][0]
+ tmp_shape = attr['_input_shapes'][inputs[0]]
input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
attr['data_format'] = "NCHW"
# NCHW layout requires weights transpose
if attr['data_format'] == 'NCHW':
- tmp_shape = attr['_input_shapes'][inputs[1]][0]
+ tmp_shape = attr['_input_shapes'][inputs[1]]
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
- attr['_input_shapes'][inputs[1]] = [tmp_shape]
+ attr['_input_shapes'][inputs[1]] = tmp_shape
- input_shape = attr['_input_shapes'][inputs[0]][0]
- weights_shape = attr['_input_shapes'][inputs[1]][0]
+ input_shape = attr['_input_shapes'][inputs[0]]
+ weights_shape = attr['_input_shapes'][inputs[1]]
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
attr['channels'] = input_shape[3] * depth_mult
if 'dilations' in attr:
- attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])
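+ # TF stores dilations in data_format order; for NHWC the
+ # height/width factors sit at indices 1 and 2, not 0 and 1.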
+ attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = weights_shape
in_h = input_shape[2]
in_w = input_shape[3]
- pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
- pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
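+ # SAME padding must be computed from the dilated kernel extent:
+ # a kernel of size k with dilation d covers (k - 1) * d + 1 input pixels.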
+ dilation_h = attr['dilations'][0]
+ dilation_w = attr['dilations'][1]
+ dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+ dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+ pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+ pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
+
if attr['data_format'] == 'NHWC':
inputs[0] = _op.nn.pad(data=inputs[0],
dim_input = inputs.pop(1)
axis = params[dim_input.name_hint]
params.pop(dim_input.name_hint)
- return AttrCvt(op_name="expand_dims", ignores=['Tdim'],
- extras={'axis': int(axis.asnumpy()[0])})(inputs, attr)
+ return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
return _impl
def _resize_bilinear():
return _impl
+def _undef():
+ def _impl(inputs, attr, params):
+ return _sym.__undef__()
+ return _impl
+
def _identity():
def _impl(inputs, attr, params):
return inputs[0]
def _pack():
def _impl(inputs, attr, params):
axis = int(attr["axis"])
- inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
+ inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
return _op.concatenate(inputs_reshaped, axis)
return _impl
+def _slice():
+ def _impl(inputs, attr, params):
+ begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist()
+ size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist()
+ data_shape = attr['_input_shapes'][inputs[0]]
+ data_dim = len(data_shape)
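+ # TF Slice takes (begin, size) with size == -1 meaning "to the end of
+ # the dimension"; Relay strided_slice takes (begin, end), so convert.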
+ end = size[:]
+ for i in range(data_dim):
+ if size[i] == -1:
+ end[i] = data_shape[i] - begin[i]
+ else:
+ end[i] += begin[i]
+ return _op.strided_slice(inputs[0], begin=begin, end=end)
+ return _impl
+
+
def _reshape():
def _impl(inputs, attr, params):
try:
def _shape():
def _impl(inputs, attr, params):
- return np.array(attr['_input_shapes'][inputs[0]][0], dtype='int32')
+ return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
return _impl
def _fill():
new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
data_shape = attr['_input_shapes'][inputs[0]]
- data_dim = len(data_shape[0])
+ data_dim = len(data_shape)
stride_dim = len(stride)
def _transform_mask(stride_dim, ellipsis_mask):
+ new_axes_after_ellipsis), data_dim)
for i in range(final_index, to_index):
m_begin[final_index] = 0
- m_end[final_index] = data_shape[0][final_index]
+ m_end[final_index] = data_shape[final_index]
m_stride[final_index] = 1
fshape_indices.append(final_index)
final_index += 1
if final_index == len(m_begin):
break
if mask & begin_mask:
- m_begin[final_index] = data_shape[0][final_index] \
+ m_begin[final_index] = data_shape[final_index] \
if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
m_end[final_index] = 0 if stride[index] < 0 \
- else data_shape[0][final_index]
+ else data_shape[final_index]
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
if mask & shrink_axis_mask:
#TensorFlow slices an axis selected by shrink_axis_mask down to a single element; the trailing reshape removes that axis
- m_begin[final_index] = data_shape[0][final_index] + begin[index] \
+ m_begin[final_index] = data_shape[final_index] + begin[index] \
if begin[index] < 0 else begin[index]
m_end[final_index] = begin[index] + 1
m_stride[final_index] = 1
pass
else:
final_output.append(out_shape[gather_index])
+ # Prevent 0-dim tensors which are not accepted by Relay
+ if not final_output:
+ final_output.append(1)
return _op.reshape(out, newshape=tuple(final_output))
return _impl
def _rank():
def _impl(inputs, attr, params):
- input_shapes = attr['_input_shapes'][inputs[0]]
- assert len(inputs) == 1
+ input_shape = attr['_input_shapes'][inputs[0]]
name = attr["_node_name"]
- params[name] = tvm.nd.array([len(input_shapes[0])])
+ params[name] = tvm.nd.array([len(input_shape)])
return [_expr.var(name,
shape=params[name].shape,
dtype='int32')]
)(inputs, attr)
return _impl
+def _split(has_size_vector):
+ # TF documentation https://www.tensorflow.org/api_docs/python/tf/split
+ def _impl(inputs, attr, params):
+ try:
+ # order and number of inputs are different:
+ # if has_size_vector:
+ # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
+ # else:
+ # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split
+
+ # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow,
+ # but we only support constants here
+ if has_size_vector:
+ input_node_index = 0
+ input_axis_index = 2
+ size_splits_input_name = _get_name_hint(inputs[1])
+ size_splits = params[size_splits_input_name].asnumpy()
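+ # Relay expects the indices where new sections begin, e.g.
+ # size_splits [1, 4, 1] -> cumsum [1, 5, 6] -> boundaries (1, 5).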
+ section_beginnings = np.cumsum(size_splits)[:-1]
+ indices_or_sections = tuple(section_beginnings)
+ else:
+ input_node_index = 1
+ input_axis_index = 0
+ indices_or_sections = attr['num_split']
+ input_node = inputs[input_node_index]
+ axis_input_name = _get_name_hint(inputs[input_axis_index])
+ axis_input_value = params[axis_input_name].asnumpy()[0]
+ except (IndexError, KeyError):
+ raise TypeError( \
+ "Unsupported argument for split: `axis` and `num_or_size_splits` " \
+ "should be constants")
+ return _op.split(input_node,
+ indices_or_sections=indices_or_sections,
+ axis=int(axis_input_value))
+ return _impl
+
+def _unpack():
+ def _impl(inputs, attr, params):
+ input_node = inputs[0]
+ axis = attr['axis']
+ input_shape = attr['_input_shapes'][input_node]
+ axis_length = input_shape[axis]
+ if axis_length < 0:
+ raise TypeError("Unstack with unknown axis length")
+ splitted = _op.split(input_node,
+ indices_or_sections=axis_length,
+ axis=axis)
+ if axis == 0:
+ axis = None
+ else:
+ axis = [axis]
+ return _expr.TupleWrapper(
+ _expr.Tuple([_op.squeeze(split_item, axis=axis) \
+ for split_item in splitted]), len(splitted))
+ return _impl
+
+def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
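+ # Relay has no 0-d tensors, so TF scalars are imported as 1-d (tracked
+ # in `_input_0d_mismatch`); one requested axis therefore already exists
+ # and we only expand by `num_newaxis - 1`.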
+ if data in attr['_input_0d_mismatch']:
+ return data if num_newaxis == 1 else \
+ AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
+ extras={'axis': int(axis), 'num_newaxis': int(num_newaxis-1)})([data], attr)
+
+ return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
+ extras={'axis': int(axis), 'num_newaxis': int(num_newaxis)})([data], attr)
+
+
def _softmax():
def _impl(inputs, attr, params):
return AttrCvt(op_name='softmax',
'Add' : _elemwise('add'),
'Sub' : _elemwise('subtract'),
'Mul' : _elemwise('multiply'),
+ 'RealDiv' : _elemwise('divide'),
'Maximum' : _elemwise('maximum'),
'Minimum' : _elemwise('minimum'),
'Sum' : _sum(),
'Square' : _square(),
'Pack' : _pack(),
+ 'Slice' : _slice(),
'LeakyRelu' : AttrCvt('leaky_relu'),
'Relu' : AttrCvt('relu'),
'Reshape' : _reshape(),
'GreaterEqual' : _broadcast('greater_equal'),
'Equal' : _broadcast('equal'),
'NotEqual' : _broadcast('not_equal'),
+ 'Split' : _split(False),
+ 'SplitV' : _split(True),
+ 'Unpack' : _unpack(),
}
def _LSTMBlockCell():
forget_bias = attr.pop('forget_bias')
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
- batch_size, input_size = input_shape[0][0], input_shape[0][1]
- num_hidden_layers = weight_shape[0][1]
+ batch_size, input_size = input_shape[0], input_shape[1]
+ num_hidden_layers = weight_shape[1]
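+ # LSTMBlockCell packs the weights of the four gates (input, new input,
+ # forget, output) into one matrix, so the hidden size is a quarter of
+ # its second dimension.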
num_hidden = num_hidden_layers // 4
in_data = _op.reshape(in_data,
input_shape = attr['_input_shapes'][inputs[0]]
weight_shape = attr['_input_shapes'][inputs[3]]
- batch_size = input_shape[0][0]
- num_hidden = weight_shape[0][1] // 4
+ batch_size = input_shape[0]
+ num_hidden = weight_shape[1] // 4
if layer == 0:
#Create initial state placeholder for the first layer
self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False
+ self._outputs_are_0d = {}
+ self._input_shapes = {}
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
# Operator name 'Const' is treated as a parameter to build params dict.
input_shapes = {}
+ input_0d_mismatch = set()
attr = self._parse_attr(node.attr)
# Variable converted to Const will not have only value attr
elif shape and node.name in shape:
# Give priority to user argument.
self._output_shapes[node.name] = [shape[node.name]]
+ elif node.op == 'Placeholder':
+ self._output_shapes[node.name] = [self._input_shapes[node.name]]
elif '_output_shapes' in attr:
self._output_shapes[node.name] = \
[tensor_util.TensorShapeProtoToList(tshape) \
else:
# Keep the list indexable to avoid key error.
# Actual value will be filled after node creation.
+ # Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None]
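+ # Track which outputs are 0-d in TF: Relay cannot represent them, so they
+ # are imported as 1-d and reconciled later in `_expand_dims_0d_aware`.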
+ self._outputs_are_0d[node.name] = [ \
+ not tshape if isinstance(tshape, list) else False \
+ for tshape in self._output_shapes[node.name]]
+
if node.op == "Placeholder":
self._output_shapes[node.name] = [self._input_shapes[node.name]]
self._nodes[node.name] = [_expr.var(node.name,
# Fill shapes for all inputs in a list
inputs = []
for i in node.input:
- if i in self._nodes:
- inputs.append(self._nodes[i][0])
- input_shapes[self._nodes[i][0]] = self._output_shapes[i]
+ # Some TensorFlow operators internally maintain execution layers
+ # and their output name includes the layer number along with
+ # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
+ # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
+ # the number has to be ignored for single-output nodes.
+ # On the other hand, for multi-output nodes the number is the output index,
+ # and the lack of the number implies 0.
+ tensor_name = i.split(':')
+ node_name = tensor_name[0]
+ if node_name in self._nodes:
+ in_sym = self._nodes[node_name]
+ if isinstance(in_sym, _expr.TupleWrapper):
+ tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
+ in_sym = [in_sym[tensor_slot]]
+ input_shape = self._output_shapes[node_name][tensor_slot]
+ else:
+ tensor_slot = 0
+ input_shape = self._output_shapes[node_name][0]
+ inputs.append(in_sym[0])
+ input_shapes[in_sym[0]] = input_shape
+ # This means the node is 1d in Relay and 0d in TF.
+ # See `_expand_dims_0d_aware`.
+ if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
+ input_0d_mismatch.add(in_sym[0])
+
attr['_input_shapes'] = input_shapes
+ attr['_input_0d_mismatch'] = input_0d_mismatch
op = self._convert_operator(node.op, inputs, attr, graph)
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
- out_type = ir_pass.infer_type(self._nodes[node.name][0])
- self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
+ out_shapes = []
+ for node_item in self._nodes[node.name]:
+ out_type = ir_pass.infer_type(node_item)
+ out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+ self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])
# Infer shapes if passed explicitly
node_output = self._nodes[node.name]
- out_type = ir_pass.infer_type(node_output[0])
- self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
-
+ if shape and (not self._output_shapes[node.name][0]
+ or -1 in self._output_shapes[node.name][0]):
+ out_shapes = []
+ for node_item in node_output:
+ out_type = ir_pass.infer_type(node_item)
+ out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+ self._output_shapes[node.name] = out_shapes
out = []
if outputs is None:
out = op
else:
- out = [self._nodes[out_name][0] for out_name in outputs]
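+ # Outputs may be given as 'node_name:slot' to select one output of a
+ # multi-output node; a bare name means slot 0.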
+ for out_name in outputs:
+ if ":" in out_name:
+ out_name, out_num = out_name.split(":")
+ out_num = int(out_num)
+ out.append(self._nodes[out_name][out_num])
+ else:
+ out.append(self._nodes[out_name][0])
#Add the RNN outputs also with 'head' nodes of the relay graph
if self._num_rnn_layer:
if no_gpu and device == 'cuda':
continue
- tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device)
+ tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device,
+ out_names=out_name, num_output=len(out_name))
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
if is_gpu_available():
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
- kwargs['data_layout'] = 'NCHW'
+ kwargs['data_format'] = 'NCHW'
_test_pooling_iteration(input_shape, **kwargs)
def test_forward_pooling():
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
- strides = [1] + strides + [1]
- dilations = [1] + dilations + [1]
+ if data_format == 'NHWC':
+ strides = [1] + strides + [1]
+ dilations = [1] + dilations + [1]
+ else:
+ strides = [1, 1] + strides
+ dilations = [1, 1] + dilations
nn_ops.conv2d(in_data,
in_filter,
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
+#######################################################################
+# Split
+# -----
+
+def _test_split(in_shape, axis, num_or_size_splits, dtype):
+ """ One iteration of a Split """
+ np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype)
+
+ tf.reset_default_graph()
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ num_split = len(num_or_size_splits) if isinstance(num_or_size_splits, list) else num_or_size_splits
+ tf.split(in_data, num_or_size_splits, axis=axis)
+
+ compare_tf_with_tvm([np_data], ['in_data:0'], ['split:{0}'.format(n) for n in range(num_split)])
+
+ # and now test together with concat
+ tf.reset_default_graph()
+ in_data = tf.placeholder(dtype, in_shape, name="in_data")
+ splitted = tf.split(in_data, num_or_size_splits, axis=axis)
+ tf.concat(splitted, axis)
+
+ compare_tf_with_tvm([np_data], 'in_data:0', 'concat:0')
+
+def test_forward_split():
+ '''test split layer'''
+ # rank 1
+ _test_split((3,), 0, 1, 'float32')
+ _test_split((3,), 0, 3, 'float32')
+ _test_split((6,), 0, 3, 'float32')
+ # rank 2
+ _test_split((6, 2), 0, 3, 'float32')
+ _test_split((2, 6), 1, 6, 'float32')
+ # rank 3
+ _test_split((6, 2, 4), 0, 2, 'int32')
+ _test_split((2, 6, 4), 1, 3, 'float32')
+ _test_split((2, 4, 6), 2, 1, 'float32')
+ # rank 4
+ _test_split((6, 1, 3, 5), 0, 3, 'float32')
+ _test_split((1, 6, 3, 5), 1, 3, 'float32')
+ _test_split((1, 3, 6, 5), 2, 3, 'float32')
+ _test_split((1, 3, 5, 6), 3, 3, 'float32')
+ # split along negative axis
+ _test_split((6, 1, 3, 5), -4, 3, 'float32')
+ _test_split((1, 6, 3, 5), -3, 3, 'float32')
+ _test_split((1, 3, 6, 5), -2, 3, 'float32')
+ _test_split((1, 3, 5, 6), -1, 3, 'float32')
+ # size_splits list
+ _test_split((6,), 0, [1, 2, 3], 'int32')
+ _test_split((3, 6, 4), -2, [1, 4, 1], 'float32')
+
+
+#######################################################################
+# Unstack
+# -------
+
+def _test_unstack(ip_shape, axis, dtype):
+ np_data = np.random.uniform(-5, 5, size=ip_shape).astype(dtype)
+
+ tf.reset_default_graph()
+ in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+ tf.unstack(in_data, axis=axis)
+
+ compare_tf_with_tvm([np_data], ['in_data:0'], ['unstack:{0}'.format(n) for n in range(ip_shape[axis])])
+
+ tf.reset_default_graph()
+ in_data = tf.placeholder(dtype, ip_shape, name="in_data")
+ tf.stack(tf.unstack(in_data, axis=axis), axis=axis)
+
+ compare_tf_with_tvm([np_data], ['in_data:0'], 'stack:0')
+
+def test_forward_unstack():
+ '''test unstack layer'''
+ _test_unstack((6,), 0, 'int32')
+ _test_unstack((2,6), 1, 'float64')
+ # negative axis
+ _test_unstack((1,4), -1, 'int32')
+ _test_unstack((3,6,4), -2, 'float32')
+
#######################################################################
# Multi Input to graph
_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
+#######################################################################
+# Crop to bounding box
+# --------------------
+
+def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
+ """ Crop to bounding box """
+ data = np.random.uniform(size=in_shape).astype('float32')
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
+ compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0')
+
+def test_forward_crop():
+ """ Crop to bounding box """
+ _test_crop((1, 224, 224, 3), 20, 20, 120, 120)
+
#######################################################################
# LSTM
#######################################################################
# ResnetV2
-# ---------
+# --------
def test_forward_resnetv2():
'''test resnet model'''
if is_gpu_available():
with tf.Session() as sess:
tf_output = run_tf_graph(sess, data, 'input_tensor:0', out_node + ':0')
- tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', tf_output.shape, 'float32')
- tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
+ for device in ["llvm", "cuda"]:
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ continue
+ tvm_output = run_tvm_graph(graph_def, data, 'input_tensor', len(tf_output), target=device)
+ tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5)
#######################################################################
# PTB
test_forward_squeeze()
test_forward_pack()
test_forward_resize_bilinear()
+ test_forward_crop()
test_forward_pad()
test_forward_gather()
test_forward_stridedslice()
+ test_forward_split()
+ test_forward_unstack()
# Activations
test_forward_sigmoid()