return op
class AttrCvt(object):
- """Common attribute conveter. An AttrConverter instance is a callable:
+ """Common attribute converter. An AttrConverter instance is a callable:
```
attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
new_op_name, new_attr = attr_converter(attrs)
return False
return _dim_check, "Only 2d kernel supported."
-def _infer_channels(inputs, params, transpose=False):
- """A hack for getting 'channles' or 'units' since tensorflow don't provide
+def _infer_channels(node, params, transpose=False):
+ """A hack for getting 'channels' or 'units' since tensorflow don't provide
these attributes. We check the shape of weights provided to get the number.
"""
- out_type = ir_pass.infer_type(inputs)
- out_shapes = [get_const_tuple(out_type.checked_type.shape)]
- channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
+ out_shape = _infer_shape(node, params)
+ channels = out_shape[0] if not transpose else out_shape[1]
return channels
+def _infer_out_shapes(inputs, params):
+ """A method to get the output shape of intermediate nodes in the relay graph."""
+ return [_infer_shape(inputs, params)]
+
+def _infer_shape(node, params=None):
+ """A method to get the output shape of an intermediate node in the relay graph."""
+ out_type = ir_pass.infer_type(node)
+ return get_const_tuple(out_type.checked_type.shape)
+
+def _get_param(params, input_node):
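+ """Return the value of a constant input node as a numpy array, removing it from params."""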
+ return params.pop(input_node.name_hint).asnumpy()
+
+def _get_num_param(params, input_node):
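+ """Return the value of a constant input node as a scalar (its first element)."""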
+ return _get_param(params, input_node)[0]
+
+def _get_list_param(params, input_node):
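+ """Return the value of a constant input node as a Python list."""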
+ return _get_param(params, input_node).tolist()
+
+def _get_tuple_param(params, input_node):
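+ """Return the value of a constant input node as a Python tuple."""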
+ return tuple(_get_param(params, input_node))
+
def _rsqrt():
- def _impl(inputs, attr, *args):
+ def _impl(inputs, attr, params):
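+ # rsqrt(x) is lowered to power(x, -0.5) by appending the exponent as a constant input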
inputs.append(tvm.relay.const(-0.5, attr['T'].name))
return AttrCvt(op_name="power")(inputs, attr)
return _impl
try:
# In TensorFlow, the `axis` argument is a Tensor, not an attribute. We
# support the case where it is provided as a scalar constant.
- axis_input_name = inputs[1].name_hint
- axis_input_vlaue = [params[axis_input_name].asnumpy()[0]]
+ axis_input_value = [_get_num_param(params, inputs[1])]
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
- return func(inputs[0], axis=axis_input_vlaue, keepdims=False)
+ return func(inputs[0], axis=axis_input_value, keepdims=False)
return _impl
def _elemwise(name):
- def _impl(inputs, attr, *args):
+ def _impl(inputs, attr, params):
assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
return _get_relay_op(name)(*inputs)
return _impl
def _expand_dims():
def _impl(inputs, attr, params):
dim_input = inputs.pop(1)
- axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0]
+ axis = _get_num_param(params, dim_input)
return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
return _impl
def _concatV2():
def _impl(inputs, attr, params):
pop_node = inputs.pop(len(inputs)-1)
- axis = params[pop_node.name_hint]
- params.pop(pop_node.name_hint)
+ axis = int(_get_num_param(params, pop_node))
return AttrCvt(
op_name="concatenate", ignores=['T', 'N', 'Tidx'],
- extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
+ extras={'axis': axis})([inputs], attr)
return _impl
def _concat():
def _impl(inputs, attr, params):
pop_node = inputs.pop(0)
- axis = params[pop_node.name_hint]
- params.pop(pop_node.name_hint)
+ axis = int(_get_num_param(params, pop_node))
return AttrCvt(
op_name="concatenate", ignores=['N'],
- extras={'axis': int(axis.asnumpy()[0])})([inputs], attr)
+ extras={'axis': axis})([inputs], attr)
return _impl
def _pack():
def _slice():
def _impl(inputs, attr, params):
- begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist()
- size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist()
+ begin = _get_list_param(params, inputs[1])
+ size = _get_list_param(params, inputs[2])
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape)
end = size
def _reshape():
def _impl(inputs, attr, params):
+ pop_node = inputs.pop(1)
try:
- pop_node = inputs[1]
- shape_arg = params.pop(pop_node.name_hint)
- inputs.pop(1)
-
- return AttrCvt(
- op_name="reshape",
- extras={'newshape':tuple(shape_arg.asnumpy())},
- ignores=['Tshape'])(inputs, attr)
+ shape_arg = _get_tuple_param(params, pop_node)
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
- params_new = _infer_value(inputs[1], params)
- inputs.pop(1)
- return AttrCvt(
- op_name="reshape",
- extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
- ignores=['Tshape'])(inputs, attr)
+ params_new = _infer_value(pop_node, params)
+ shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
+ return AttrCvt(
+ op_name="reshape",
+ extras={'newshape': shape_arg},
+ ignores=['Tshape'])(inputs, attr)
return _impl
if -1 in output_shape:
output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()
- fill_arg = params.pop(inputs.pop(1).name_hint)
- return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
- output_shape, attr['T'].name)
+ fill_arg = _get_num_param(params, inputs.pop(1))
+ dtype = attr['T'].name
+ return _op.full(tvm.relay.const(fill_arg, dtype),
+ output_shape, dtype)
return _impl
def _lrn():
def _sum():
def _impl(inputs, attr, params):
- axis = params.pop(inputs[1].name_hint).asnumpy()
- # convert to tuple for preventing invalid parameter format error
- axis = tuple(axis)
+ axis = _get_tuple_param(params, inputs[1])
return AttrCvt(
op_name='sum',
extras={'axis': axis},
def _gather():
"GatherV2, Gather"
def _impl(inputs, attr, params):
-
- axis = 0
if len(inputs) > 2:
- axis = params[inputs.pop(2).name_hint].asnumpy()[0]
- new_input = []
- new_input.append(inputs.pop(0))
- new_input.append(inputs.pop(0))
+ axis = _get_num_param(params, inputs.pop(2))
+ else:
+ axis = 0
+ new_input = inputs[0:2]
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
- ignores=['Tindices', 'Tparams', 'validate_indices', \
+ ignores=['Tindices', 'Tparams', 'validate_indices',
'Taxis', '_class'])(new_input, attr)
return _impl
-def _infer_out_shapes(inputs, params):
- """A method to get the output shape of an intermediate node in the relay graph."""
- out_type = ir_pass.infer_type(inputs)
- out_shapes = [get_const_tuple(out_type.checked_type.shape)]
- return out_shapes
-
def _stridedSlice():
def _impl(inputs, attr, params):
"""Strided Slice.
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
- begin = params.pop(inputs[1].name_hint).asnumpy().tolist()
- end = params.pop(inputs[2].name_hint).asnumpy().tolist()
- stride = params.pop(inputs[3].name_hint).asnumpy().tolist()
+ begin = _get_list_param(params, inputs[1])
+ end = _get_list_param(params, inputs[2])
+ stride = _get_list_param(params, inputs[3])
begin_mask = int(attr.get('begin_mask', 0))
end_mask = int(attr.get('end_mask', 0))
ellipsis_mask = int(attr.get('ellipsis_mask', 0))
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
- out_shape = _infer_out_shapes(out, params)[0]
+ out_shape = _infer_shape(out, params)
if not fshape_indices:
fshape_indices = range(len(out_shape))
def _pad(name):
def _impl(inputs, attr, params):
- padlist_key = inputs[1].name_hint
- if padlist_key in params:
- padlist = params.pop(padlist_key).asnumpy()
- else:
- raise tvm.error.OpAttributeRequired(
- 'Attribute {} not found in operator Pad.'.format(padlist_key))
- paddings = tuple([tuple(l) for l in padlist])
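+ # TF Pad provides paddings as an [n, 2] tensor of (pad_before, pad_after) pairs per axis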
+ padlist = _get_param(params, inputs[1])
+ paddings = tuple(tuple(l) for l in padlist)
attr['pad_width'] = paddings
attr['pad_value'] = 0
new_inputs = [inputs[0]]
if name == 'PadV2':
- constant_values = params.pop(inputs[2].name_hint).asnumpy()
- attr['pad_value'] = constant_values[0]
+ constant_value = _get_num_param(params, inputs[2])
+ attr['pad_value'] = constant_value
return AttrCvt(
op_name='pad',
ignores=['Tpaddings'],)(new_inputs, attr)
def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty,
# otherwise its value is taken from params
- param_name = _get_name_hint(inputs[1])
- if param_name in params:
- axes = tuple(params.get(param_name).asnumpy())
- else:
+ try:
+ axes = _get_list_param(params, inputs[1])
+ except (IndexError, KeyError):
axes = None
return _op.transpose(inputs[0], axes=axes)
return _impl
def _reverse_v2():
def _impl(inputs, attr, params):
- axis = params.pop(inputs[1].name_hint).asnumpy()[0]
+ axis = _get_num_param(params, inputs[1])
return AttrCvt(
op_name="reverse",
ignores=['Tidx'],
def _range():
def _impl(inputs, attr, params):
- start = params.pop(inputs[0].name_hint).asnumpy()[0]
- limit = params.pop(inputs[1].name_hint).asnumpy()[0]
- delta = params.pop(inputs[2].name_hint).asnumpy()[0]
+ start = _get_num_param(params, inputs[0])
+ limit = _get_num_param(params, inputs[1])
+ delta = _get_num_param(params, inputs[2])
name = attr["_node_name"]
params[name] = tvm.nd.array([start, limit, delta])
def _elu():
def _impl(inputs, attr, params):
- alpha = tvm.relay.const(-1.0, attr['T'].name)
- return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
+ dtype = attr['T'].name
+ alpha = tvm.relay.const(-1.0, dtype)
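+ # ELU: x for x > 0, exp(x) - 1 otherwise; written here as -1 * relu(1 - exp(x)) + relu(x)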
+ return alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
- _op.exp(inputs[0])) + _op.nn.relu(inputs[0])
return _impl
def _selu():
def _impl(inputs, attr, params):
- alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name)
- gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name)
- return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \
+ dtype = attr['T'].name
+ alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype)
+ gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype)
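+ # SELU: gamma * (x for x > 0, 1.6733 * (exp(x) - 1) otherwise), expressed with ReLU as in _elu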
+ return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \
- _op.exp(inputs[0])) + _op.nn.relu(inputs[0]))
return _impl
def _mean():
def _impl(inputs, attr, params):
- axis = params.pop(inputs[1].name_hint)
+ axis = _get_tuple_param(params, inputs[1])
return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
transforms={'keep_dims': 'keepdims'},
- extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr)
+ extras={'axis': axis})([inputs[0]], attr)
return _impl
def _broadcast(name):
if has_size_vector:
input_node_index = 0
input_axis_index = 2
- size_splits_input_name = _get_name_hint(inputs[1])
- size_splits = params[size_splits_input_name].asnumpy()
+ size_splits = _get_param(params, inputs[1])
section_beginnings = np.cumsum(size_splits)[:-1]
indices_or_sections = tuple(section_beginnings)
else:
input_axis_index = 0
indices_or_sections = attr['num_split']
input_node = inputs[input_node_index]
- axis_input_name = _get_name_hint(inputs[input_axis_index])
- axis_input_value = params[axis_input_name].asnumpy()[0]
+ axis_input_value = _get_num_param(params, inputs[input_axis_index])
except (IndexError, KeyError):
raise TypeError( \
"Unsupported argument for split: `axis` and `num_or_size_splits` " \
def _impl(inputs, attr, params):
input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node]
- block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
- paddings = params.pop(inputs[2].name_hint).asnumpy().tolist()
+ block_shape = _get_list_param(params, inputs[1])
+ paddings = _get_list_param(params, inputs[2])
N = len(input_shape)
M = len(block_shape)
batch = input_shape[0]
axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
- permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0]
+ permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
def _impl(inputs, attr, params):
input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node]
- block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
- crops = params.pop(inputs[2].name_hint).asnumpy().tolist()
+ block_shape = _get_list_param(params, inputs[1])
+ crops = _get_list_param(params, inputs[2])
M = len(block_shape)
batch = input_shape[0]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
- reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0]
+ reshaped_permuted_shape = _infer_shape(reshaped_permuted, params)
cropped = reshaped_permuted
for axis in range(1, M+1):
crop = crops[axis - 1]
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
- out_shapes = []
- for node_item in self._nodes[node.name]:
- out_type = ir_pass.infer_type(node_item)
- out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+ out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])
- # Infer shapes if passed explicitely
+ # Infer shapes if passed explicitly
node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
- out_shapes = []
- for node_item in node_output:
- out_type = ir_pass.infer_type(node_item)
- out_shapes.append(get_const_tuple(out_type.checked_type.shape))
+ out_shapes = [_infer_shape(node_item) for node_item in node_output]
self._output_shapes[node.name] = out_shapes
out = []
layout = None
if target == "cuda":
layout = "NCHW"
- target_host = 'llvm'
-
- if isinstance(input_data, list):
- shape_dict = {}
- dtype_dict = {}
- for i, e in enumerate(input_node):
- shape_dict[e] = input_data[i].shape
- dtype_dict[e] = input_data[i].dtype
- else:
- shape_dict = {input_node: input_data.shape}
- dtype_dict = {input_node: input_data.dtype}
+ target_host = None
+
+ shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
sym, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict,
outputs=out_names)
with relay.build_config(opt_level=opt_level):
- graph, lib, params = relay.build(sym, target, params=params)
+ graph, lib, params = relay.build(sym, target, target_host, params)
ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
- for i, e in enumerate(input_node):
- m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))
+ for e, i in zip(input_node, input_data):
+ m.set_input(e, tvm.nd.array(i))
m.set_input(**params)
# execute
# get outputs
assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output))
- tvm_output_list = []
- for i in range(0, num_output):
- tvm_output = m.get_output(i)
- tvm_output_list.append(tvm_output.asnumpy())
+ tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
return tvm_output_list
def run_tf_graph(sess, input_data, input_node, output_node):
input_node = convert_to_list(input_node)
output_node = convert_to_list(output_node)
- tensor = [0] * len(output_node)
- for i in range(len(output_node)):
- tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
+ tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]
- input_dict = {}
- for i, e in enumerate(input_node):
- input_dict[e] = input_data[i]
+ input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
output_data = sess.run(tensor, input_dict)
return output_data
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3):
"""Generic function to generate and compare tensorflow and TVM output"""
+ def name_without_num(name):
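+ """Strip the output slot suffix (e.g. ':0') from a tensor name to get the node name."""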
+ return name.split(':')[0] if ":" in name else name
out_name = convert_to_list(out_name)
- out_node = [0]*len(out_name)
- for i in range(len(out_name)):
- out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
+ out_node = [name_without_num(name) for name in out_name]
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
- in_node = [0]*len(in_name)
- for i in range(len(in_name)):
- in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
+ in_node = [name_without_num(name) for name in in_name]
with tf.Session() as sess:
if init_global_variables:
sess.run(variables.global_variables_initializer())
_test_variable(np.random.uniform(size=(32, 100)).astype('float32'))
+#######################################################################
+# MatMul
+# ------
+
+def _test_matmul(i, j, k, dtype, outer=None):
+ """ One iteration of matmul """
+
+ A_shape_init = [i, j]
+ B_shape_init = [j, k]
+
+ for transpose_a in [False, True]:
+ for transpose_b in [False, True]:
+ outer = outer or []
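+ # outer holds optional leading (batch) dimensions; empty for plain 2-D matmul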
+ A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
+ B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)
+
+ with tf.Graph().as_default():
+ A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
+ B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
+ result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)
+
+ A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
+ B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+ compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
+
+def test_forward_matmul():
+ """ Matmul op test"""
+ _test_matmul(1, 3, 6, 'int32')
+ _test_matmul(5, 3, 1, 'float64')
+ # TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
+
+
#######################################################################
# StridedSlice
# ------------
test_forward_rel_ops()
test_forward_logical()
test_where()
+
+ test_forward_matmul()
+ # TODO missing tests: rank, range
\ No newline at end of file