[Relay][Frontend] Support tf.where (#2936)
authorYong Wu <55wuyong@163.com>
Wed, 3 Apr 2019 11:10:10 +0000 (04:10 -0700)
committerSiva <sivar.b@huawei.com>
Wed, 3 Apr 2019 11:10:10 +0000 (16:40 +0530)
* [Relay][Frontend] Support tf.where

* fix comments

python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 37f2c23..0c5b40c 100644 (file)
@@ -683,10 +683,10 @@ def _gather():
         new_input = []
         new_input.append(inputs.pop(0))
         new_input.append(inputs.pop(0))
-        return  AttrCvt(op_name="take",
-                        extras={'axis': tvm.const(axis, 'int32')},
-                        ignores=['Tindices', 'Tparams', 'validate_indices', \
-                                 'Taxis', '_class'])(new_input, attr)
+        return AttrCvt(op_name="take",
+                       extras={'axis': tvm.const(axis, 'int32')},
+                       ignores=['Tindices', 'Tparams', 'validate_indices', \
+                                'Taxis', '_class'])(new_input, attr)
     return _impl
 
 def _infer_out_shapes(inputs, params):
@@ -818,7 +818,6 @@ def _pad(name):
             ignores=['Tpaddings'],)(new_inputs, attr)
     return _impl
 
-
 def _transpose():
     def _impl(inputs, attr, params):
         # If perm is not specified, axes is left empty,
@@ -831,6 +830,11 @@ def _transpose():
         return _op.transpose(inputs[0], axes=axes)
     return _impl
 
+def _where():
+    def _impl(inputs, attr, params):
+        return AttrCvt(op_name="where")(inputs, attr)
+    return _impl
+
 def _rank():
     def _impl(inputs, attr, params):
         input_shape = attr['_input_shapes'][inputs[0]]
@@ -1015,6 +1019,7 @@ _convert_map = {
     'DepthwiseConv2dNative'             : _conv('depthwise'),
     'Shape'                             : _shape(),
     'Sigmoid'                           : AttrCvt('sigmoid'),
+    'Select'                            : _where(),
     'Fill'                              : _fill(),
     'GatherV2'                          : _gather(),
     'Gather'                            : _gather(),
index 87fe53e..9d26280 100644 (file)
@@ -108,7 +108,6 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
     in_node = [0]*len(in_name)
     for i in range(len(in_name)):
         in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
-
     with tf.Session() as sess:
         if init_global_variables:
             sess.run(variables.global_variables_initializer())
@@ -483,7 +482,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
     in_data = tf.placeholder(dtype, ip_shape, name="in_data")
     indices = tf.placeholder("int32", indice_shape, name="indices")
     tf.gather(in_data, indices, axis=axis)
-    np_data = np.random.uniform(size=ip_shape).astype(dtype)
+    np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
 
     def _fill_indices(indice_value):
         indices = np.array(ip_shape, dtype=dtype)
@@ -500,14 +499,14 @@ def test_forward_gather():
     '''test GatherV2 layer'''
     _test_gather((4,), (1,), 1, 0, 'int32')
     _test_gather((4,), (1,), 1, 0, 'float32')
-    _test_gather((1,4), (1,), [0], 0, 'int32')
-    _test_gather((4,), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
-    _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'int32')
-    _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 1, 'int32')
-    _test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
-    _test_gather((3,3,3), (1,1,2), [[[1,0]]], 0, 'int32')
-    _test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
-    _test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
+    _test_gather((1, 4), (1,), [0], 0, 'int32')
+    _test_gather((4,), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32')
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'int32')
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 1, 'int32')
+    _test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32')
+    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32')
+    _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
+    _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
 
 
 def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
@@ -620,10 +619,10 @@ def _test_unstack(ip_shape, axis, dtype):
 def test_forward_unstack():
     '''test unstack layer'''
     _test_unstack((6,), 0, 'int32')
-    _test_unstack((2,6), 1, 'float64')
+    _test_unstack((2, 6), 1, 'float64')
     # negative axis
-    _test_unstack((1,4), -1, 'int32')
-    _test_unstack((3,6,4), -2, 'float32')
+    _test_unstack((1, 4), -1, 'int32')
+    _test_unstack((3, 6, 4), -2, 'float32')
 
 
 #######################################################################
@@ -864,6 +863,22 @@ def test_forward_logical():
 
 
 #######################################################################
+# Where, Select
+# -------------
+def test_where():
+    ''' Where: return elements depending on conditions'''
+    with tf.Graph().as_default():
+        with tf.Session() as sess:
+            input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1')
+            input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2')
+            mask = input1 > input2
+            tf.where(mask, input1 + 1, input2 * 2)
+            in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
+            in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
+            compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0')
+
+
+#######################################################################
 # Inception V3
 # ------------
 def test_forward_inception_v3():
@@ -1299,3 +1314,4 @@ if __name__ == '__main__':
     # Relational ops
     test_forward_rel_ops()
     test_forward_logical()
+    test_where()