__all__ = ['from_tensorflow']
+def _infer_value(input_val, params):
+ from tvm.contrib import graph_runtime
+ # Check that all free variables have associated parameters.
+ assert all(var.name_hint in params.keys() for var in ir_pass.free_vars(
+ input_val)), "All inputs to infer must be available in params."
+ func = _expr.Function(ir_pass.free_vars(input_val), input_val)
+ with tvm.relay.build_config(opt_level=0):
+ graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+ ctx = tvm.context("llvm", 0)
+ m = graph_runtime.create(graph, lib, ctx)
+ m.set_input(**params)
+ m.run()
+ return m.get_output(0)
+
def _get_relay_op(op_name):
try:
op = getattr(_op, op_name)
def _resize_bilinear():
def _impl(inputs, attr, params):
- attr['size'] = attr['_output_shapes'][0][1:3]
+ size = attr['_output_shapes'][0][1:3]
+ # Important that the size is defined. If an axis is not, we need to infer what
+ # the shape should be.
+ if -1 in size:
+ size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
+ attr['size'] = size
inputs.pop(1)
# NHWC
attr['layout'] = 'NHWC'
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
- func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
- with tvm.relay.build_config(opt_level=0):
- graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
- ctx = tvm.context("llvm", 0)
- from tvm.contrib import graph_runtime
- m = graph_runtime.create(graph, lib, ctx)
- m.set_input(**params)
- m.run()
- params_new = m.get_output(0)
+ params_new = _infer_value(inputs[1], params)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
ignores=['Tshape'])(inputs, attr)
return _impl
+
+def _depth_to_space():
+ def _impl(inputs, attr, params):
+ # Need to handle data layouts differently.
+ input_shape = attr['_input_shapes'][inputs[0]]
+ block_size = int(attr['block_size'])
+ if attr['data_format'].decode("utf-8") == 'NHWC':
+ in_n, in_h, in_w, in_c = input_shape
+ new_c = int(in_c / (block_size * block_size))
+
+ # First expand input to larger dimension.
+ expanded = _op.reshape(
+ inputs[0], newshape=(in_n, in_h, in_w, block_size, block_size, new_c))
+ # Now reorder to expand spatial blocks.
+ transposed = _op.transpose(expanded, axes=(0, 1, 3, 2, 4, 5))
+ # Finally reshape to proper output.
+ new_h = in_h * block_size
+ new_w = in_w * block_size
+ newshape = (in_n, new_h, new_w, new_c)
+
+ else: # Handle NCHW layout
+ in_n, in_c, in_h, in_w = input_shape
+ new_c = int(in_c / (block_size * block_size))
+
+ expanded = _op.reshape(
+ inputs[0], newshape=(in_n, block_size, block_size, new_c, in_h, in_w))
+ transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2))
+ new_h = in_h * block_size
+ new_w = in_w * block_size
+ newshape = (in_n, new_c, new_h, new_w)
+
+ return AttrCvt(
+ op_name="reshape",
+ extras={'newshape': newshape},
+ ignores=['data_format', 'block_size'])([transposed], attr)
+
+ return _impl
+
+
def _bias_add():
def _impl(inputs, attr, params):
- return _op.add(inputs[0], inputs[1])
+ # Must expand for proper broadcasting in NCHW.
+ if attr['data_format'].decode("utf-8") == 'NCHW':
+ bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1))
+ else:
+ bias = inputs[1]
+ return _op.add(inputs[0], bias)
+ return _impl
+
+def _broadcast_to():
+ def _impl(inputs, attr, params):
+ if isinstance(inputs[1], _expr.Var):
+ shape = params[inputs[1].name_hint]
+ else:
+ shape = _infer_value(inputs[1], params)
+ shape = list(shape.asnumpy().reshape([-1]))
+ return _op.broadcast_to(inputs[0], shape)
return _impl
def _squeeze():
def _fill():
def _impl(inputs, attr, params):
+ output_shape = attr['_output_shapes'][0]
+ # Output shape must be defined to avoid errors. If any axis is not, we must
+ # try to compute its shape.
+ if -1 in output_shape:
+ output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()
+
fill_arg = params.pop(inputs.pop(1).name_hint)
return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
- attr['_output_shapes'][0], attr['T'].name)
+ output_shape, attr['T'].name)
return _impl
def _lrn():
'BatchNormWithGlobalNormalization' : _batch_norm(),
'BatchToSpaceND' : _batch_to_space_nd(),
'BiasAdd' : _bias_add(),
+ 'BroadcastTo' : _broadcast_to(),
'Cast' : _cast(),
'Ceil' : AttrCvt('ceil'),
'CheckNumerics' : _check_numerics(),
'Conv2D' : _conv('conv'),
'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'),
+ 'DepthToSpace' : _depth_to_space(),
'Equal' : _broadcast('equal'),
'Elu' : _elu(),
'Exp' : AttrCvt('exp'),
'Prod' : _prod(),
'Range' : _range(),
'Rank' : _rank(),
- 'RealDiv' : _elemwise('div'),
+ 'RealDiv' : _elemwise('divide'),
'Relu' : AttrCvt('relu'),
'Relu6' : _relu6(),
'Reshape' : _reshape(),
'ResizeBilinear' : _resize_bilinear(),
+ 'ResizeBicubic' : _resize_bilinear(),
'ReverseV2' : _reverse_v2(),
'Round' : AttrCvt('round'),
'Rsqrt' : _rsqrt(),
x = [x]
return x
-def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None):
+def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
+ target='llvm', out_names=None, opt_level=3):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout=layout,
shape=shape_dict,
outputs=out_names)
- with relay.build_config(opt_level=3):
+ with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(sym, target, params=params)
ctx = tvm.context(target, 0)
# execute
m.run()
# get outputs
- assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format(
- out_names, num_output)
+ assert out_names is None or num_output == len(out_names), (
+ "out_names: {} num_output: {}".format(out_names, num_output))
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
return output_data
-def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
+def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
+ no_gpu=False, opt_level=3):
"""Generic function to generate and compare tensorflow and TVM output"""
out_name = convert_to_list(out_name)
if no_gpu and device == 'cuda':
continue
- tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device,
- out_names=out_name, num_output=len(out_name))
+ tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
+ target=device, out_names=out_name,
+ num_output=len(out_name), opt_level=opt_level)
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
_test_reshape(np.arange(6), [-1])
#######################################################################
+# DepthToSpace
+# ------------
+
+def _test_depthtospace(data, block_size):
+ """ One iteration of depth_to_space operation with given data and block size """
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ array_ops.depth_to_space(in_data, block_size)
+
+ compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0')
+
+def test_forward_depthtospace():
+ _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
+ _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)
+
+
#######################################################################
# Squeeze
# -------
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
- shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
+ shape_data = constant_op.constant(
+ shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)
compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
+def _test_resize_bilinear_from_tensor(in_shape, align_corners):
+ """ One iteration of resize bilinear with non-constant output shape, requires
+ value inference to get proper output shape."""
+
+ data = np.random.uniform(size=in_shape).astype('float32')
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(
+ shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
+ to_shape = tf.shape(in_data)[2:]
+ tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners)
+
+ compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
+
def test_forward_resize_bilinear():
""" Resize Bilinear """
_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
+ _test_resize_bilinear_from_tensor((4, 16, 32, 32), False)
+ _test_resize_bilinear_from_tensor((6, 32, 50, 50), True)
+
+#######################################################################
+# BroadcastTo
+# -----------
+
+def _test_broadcast_to(in_shape, to_shape):
+ """ One iteration of broadcast_to"""
+
+ data = np.random.uniform(size=in_shape).astype('float32')
+ shape_data = np.array(to_shape).astype('int32')
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ shape_data = constant_op.constant(
+ shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
+ tf.broadcast_to(in_data, shape_data)
+
+ compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0', opt_level=0)
+
+
+def _test_broadcast_to_from_tensor(in_shape):
+ """ One iteration of broadcast_to with unknown shape at graph build"""
+
+ data = np.random.uniform(size=in_shape).astype('float32')
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(
+ shape=[None], dtype=data.dtype)
+
+ shape_data = tf.multiply(tf.shape(in_data), 32)
+ tf.broadcast_to(in_data, shape_data)
+
+ compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0')
+
+
+def test_forward_broadcast_to():
+ """ Resize Bilinear """
+
+ _test_broadcast_to((4, 1, 32, 32), [4, 8, 32, 32])
+ _test_broadcast_to((6, 32, 32, 1), [6, 32, 32, 16])
+ _test_broadcast_to_from_tensor((1))
+
+
+#######################################################################
+# Fill
+# ----
+
+def _test_fill(in_shape):
+ """ Use the fill op to create a tensor of ones with non-constant shape."""
+
+ with tf.Graph().as_default():
+ tf.ones(shape=in_shape, dtype='float32')
+ compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1)
+
+def _test_fill_from_tensor(in_shape):
+ """ Use the fill op to create a tensor of ones with non-constant shape.
+ Some extra ops need to be added here to prevent the graph from
+ being fully constant and folded away."""
+
+ data = np.random.uniform(size=in_shape).astype('float32')
+
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(
+ shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
+
+ x = tf.ones(shape=2*tf.shape(in_data), dtype=data.dtype)
+ y = tf.math.add(in_data, tf.reduce_mean(x), name='out1')
+ compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0')
+
+def test_forward_fill():
+ """ Resize Bilinear """
+
+ _test_fill((32))
+ _test_fill((6, 32, 64, 64))
+ _test_fill_from_tensor((6, 32, 64, 64))
#######################################################################
# Crop to bounding box
# Transforms
test_forward_transpose()
test_forward_reshape()
+ test_forward_depthtospace()
test_forward_squeeze()
test_forward_pack()
test_forward_resize_bilinear()
+ test_forward_broadcast_to()
+ test_forward_fill()
test_forward_crop()
test_forward_pad()
test_forward_gather()