new_input = []
new_input.append(inputs.pop(0))
new_input.append(inputs.pop(0))
- return AttrCvt(op_name="take",
- extras={'axis': tvm.const(axis, 'int32')},
- ignores=['Tindices', 'Tparams', 'validate_indices', \
- 'Taxis', '_class'])(new_input, attr)
+ return AttrCvt(op_name="take",
+ extras={'axis': tvm.const(axis, 'int32')},
+ ignores=['Tindices', 'Tparams', 'validate_indices', \
+ 'Taxis', '_class'])(new_input, attr)
return _impl
def _infer_out_shapes(inputs, params):
ignores=['Tpaddings'],)(new_inputs, attr)
return _impl
-
def _transpose():
def _impl(inputs, attr, params):
# If perm is not specified, axes is left empty,
return _op.transpose(inputs[0], axes=axes)
return _impl
+def _where():
+ def _impl(inputs, attr, params):
+ return AttrCvt(op_name="where")(inputs, attr)
+ return _impl
+
def _rank():
def _impl(inputs, attr, params):
input_shape = attr['_input_shapes'][inputs[0]]
'DepthwiseConv2dNative' : _conv('depthwise'),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
+ 'Select' : _where(),
'Fill' : _fill(),
'GatherV2' : _gather(),
'Gather' : _gather(),
in_node = [0]*len(in_name)
for i in range(len(in_name)):
in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
-
with tf.Session() as sess:
if init_global_variables:
sess.run(variables.global_variables_initializer())
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
indices = tf.placeholder("int32", indice_shape, name="indices")
tf.gather(in_data, indices, axis=axis)
- np_data = np.random.uniform(size=ip_shape).astype(dtype)
+ np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
def _fill_indices(indice_value):
indices = np.array(ip_shape, dtype=dtype)
'''test GatherV2 layer'''
_test_gather((4,), (1,), 1, 0, 'int32')
_test_gather((4,), (1,), 1, 0, 'float32')
- _test_gather((1,4), (1,), [0], 0, 'int32')
- _test_gather((4,), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
- _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'int32')
- _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 1, 'int32')
- _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
- _test_gather((3,3,3), (1,1,2), [[[1,0]]], 0, 'int32')
- _test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
- _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
+ _test_gather((1, 4), (1,), [0], 0, 'int32')
+ _test_gather((4,), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32')
+ _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'int32')
+ _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 1, 'int32')
+ _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32')
+ _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32')
+ _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
+ _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
def test_forward_unstack():
'''test unstack layer'''
_test_unstack((6,), 0, 'int32')
- _test_unstack((2,6), 1, 'float64')
+ _test_unstack((2, 6), 1, 'float64')
# negative axis
- _test_unstack((1,4), -1, 'int32')
- _test_unstack((3,6,4), -2, 'float32')
+ _test_unstack((1, 4), -1, 'int32')
+ _test_unstack((3, 6, 4), -2, 'float32')
#######################################################################
#######################################################################
+# Where, Select
+# -------------
+def test_where():
+ ''' Where: return elements depending on conditions'''
+ with tf.Graph().as_default():
+ with tf.Session() as sess:
+ input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1')
+ input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2')
+ mask = input1 > input2
+ tf.where(mask, input1 + 1, input2 * 2)
+ in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
+ in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
+ compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0')
+
+
+#######################################################################
# Inception V3
# ------------
def test_forward_inception_v3():
# Relational ops
test_forward_rel_ops()
test_forward_logical()
+ test_where()